mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 08:24:25 +01:00
Compare commits
1 Commits
changelog-
...
jay.wren-d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5ceac2b078 |
22
.github/ISSUE_TEMPLATE/config.yml
vendored
22
.github/ISSUE_TEMPLATE/config.yml
vendored
@@ -1,21 +1,13 @@
|
||||
blank_issues_enabled: true
|
||||
contact_links:
|
||||
- name: 💨 Performance Issues
|
||||
url: https://github.com/slackhq/nebula/discussions/new/choose
|
||||
about: 'We ask that you create a discussion instead of an issue for performance-related questions. This allows us to have a more open conversation about the issue and helps us to better understand the problem.'
|
||||
|
||||
- name: 📄 Documentation Issues
|
||||
url: https://github.com/definednet/nebula-docs
|
||||
about: "If you've found an issue with the website documentation, please file it in the nebula-docs repository."
|
||||
|
||||
- name: 📱 Mobile Nebula Issues
|
||||
url: https://github.com/definednet/mobile_nebula
|
||||
about: "If you're using the mobile Nebula app and have found an issue, please file it in the mobile_nebula repository."
|
||||
|
||||
- name: 📘 Documentation
|
||||
url: https://nebula.defined.net/docs/
|
||||
about: 'The documentation is the best place to start if you are new to Nebula.'
|
||||
about: Review documentation.
|
||||
|
||||
- name: 💁 Support/Chat
|
||||
url: https://join.slack.com/t/nebulaoss/shared_invite/zt-39pk4xopc-CUKlGcb5Z39dQ0cK1v7ehA
|
||||
about: 'For faster support, join us on Slack for assistance!'
|
||||
url: https://join.slack.com/t/nebulaoss/shared_invite/enQtOTA5MDI4NDg3MTg4LTkwY2EwNTI4NzQyMzc0M2ZlODBjNWI3NTY1MzhiOThiMmZlZjVkMTI0NGY4YTMyNjUwMWEyNzNkZTJmYzQxOGU
|
||||
about: 'This issue tracker is not for support questions. Join us on Slack for assistance!'
|
||||
|
||||
- name: 📱 Mobile Nebula
|
||||
url: https://github.com/definednet/mobile_nebula
|
||||
about: 'This issue tracker is not for mobile support. Try the Mobile Nebula repo instead!'
|
||||
|
||||
11
.github/pull_request_template.md
vendored
11
.github/pull_request_template.md
vendored
@@ -1,11 +0,0 @@
|
||||
<!--
|
||||
Thank you for taking the time to submit a pull request!
|
||||
|
||||
Please be sure to provide a clear description of what you're trying to achieve with the change.
|
||||
|
||||
- If you're submitting a new feature, please explain how to use it and document any new config options in the example config.
|
||||
- If you're submitting a bugfix, please link the related issue or describe the circumstances surrounding the issue.
|
||||
- If you're changing a default, explain why you believe the new default is appropriate for most users.
|
||||
|
||||
P.S. If you're only updating the README or other docs, please file a pull request here instead: https://github.com/DefinedNet/nebula-docs
|
||||
-->
|
||||
6
.github/workflows/gofmt.yml
vendored
6
.github/workflows/gofmt.yml
vendored
@@ -14,11 +14,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25'
|
||||
go-version: '1.24'
|
||||
check-latest: true
|
||||
|
||||
- name: Install goimports
|
||||
|
||||
32
.github/workflows/release.yml
vendored
32
.github/workflows/release.yml
vendored
@@ -10,11 +10,11 @@ jobs:
|
||||
name: Build Linux/BSD All
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25'
|
||||
go-version: '1.24'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
@@ -24,7 +24,7 @@ jobs:
|
||||
mv build/*.tar.gz release
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v5
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: linux-latest
|
||||
path: release
|
||||
@@ -33,11 +33,11 @@ jobs:
|
||||
name: Build Windows
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25'
|
||||
go-version: '1.24'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
@@ -55,7 +55,7 @@ jobs:
|
||||
mv dist\windows\wintun build\dist\windows\
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v5
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: windows-latest
|
||||
path: build
|
||||
@@ -66,11 +66,11 @@ jobs:
|
||||
HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }}
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25'
|
||||
go-version: '1.24'
|
||||
check-latest: true
|
||||
|
||||
- name: Import certificates
|
||||
@@ -104,7 +104,7 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v5
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: darwin-latest
|
||||
path: ./release/*
|
||||
@@ -124,11 +124,11 @@ jobs:
|
||||
# be overwritten
|
||||
- name: Checkout code
|
||||
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
|
||||
uses: actions/checkout@v5
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Download artifacts
|
||||
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
|
||||
uses: actions/download-artifact@v6
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: linux-latest
|
||||
path: artifacts
|
||||
@@ -160,10 +160,10 @@ jobs:
|
||||
needs: [build-linux, build-darwin, build-windows]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Download artifacts
|
||||
uses: actions/download-artifact@v6
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
path: artifacts
|
||||
|
||||
|
||||
6
.github/workflows/smoke-extra.yml
vendored
6
.github/workflows/smoke-extra.yml
vendored
@@ -20,11 +20,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25'
|
||||
go-version-file: 'go.mod'
|
||||
check-latest: true
|
||||
|
||||
- name: add hashicorp source
|
||||
|
||||
6
.github/workflows/smoke.yml
vendored
6
.github/workflows/smoke.yml
vendored
@@ -18,11 +18,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25'
|
||||
go-version: '1.24'
|
||||
check-latest: true
|
||||
|
||||
- name: build
|
||||
|
||||
36
.github/workflows/test.yml
vendored
36
.github/workflows/test.yml
vendored
@@ -18,11 +18,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25'
|
||||
go-version: '1.24'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
@@ -32,9 +32,9 @@ jobs:
|
||||
run: make vet
|
||||
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
uses: golangci/golangci-lint-action@v7
|
||||
with:
|
||||
version: v2.5
|
||||
version: v2.0
|
||||
|
||||
- name: Test
|
||||
run: make test
|
||||
@@ -45,7 +45,7 @@ jobs:
|
||||
- name: Build test mobile
|
||||
run: make build-test-mobile
|
||||
|
||||
- uses: actions/upload-artifact@v5
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: e2e packet flow linux-latest
|
||||
path: e2e/mermaid/linux-latest
|
||||
@@ -56,11 +56,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25'
|
||||
go-version: '1.24'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
@@ -77,11 +77,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25'
|
||||
go-version: '1.22'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
@@ -98,11 +98,11 @@ jobs:
|
||||
os: [windows-latest, macos-latest]
|
||||
steps:
|
||||
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25'
|
||||
go-version: '1.24'
|
||||
check-latest: true
|
||||
|
||||
- name: Build nebula
|
||||
@@ -115,9 +115,9 @@ jobs:
|
||||
run: make vet
|
||||
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
uses: golangci/golangci-lint-action@v7
|
||||
with:
|
||||
version: v2.5
|
||||
version: v2.0
|
||||
|
||||
- name: Test
|
||||
run: make test
|
||||
@@ -125,7 +125,7 @@ jobs:
|
||||
- name: End 2 end
|
||||
run: make e2evv
|
||||
|
||||
- uses: actions/upload-artifact@v5
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: e2e packet flow ${{ matrix.os }}
|
||||
path: e2e/mermaid/${{ matrix.os }}
|
||||
|
||||
60
CHANGELOG.md
60
CHANGELOG.md
@@ -7,64 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
## [1.10.0] - ????
|
||||
|
||||
### Added
|
||||
|
||||
- PKCS11 support for P256 keys when built with `pkcs11` tag (#1153)
|
||||
- ASN.1 based v2 nebula certificates with support for ipv6 and multiple ip addresses.
|
||||
Certificates now have a unified interface for external implementations. (#1212, #1216, #1345)
|
||||
**TODO: External documentation link!**
|
||||
- Add the ability to mark packets on linux to better target nebula packets in iptables/nftables. (#1331)
|
||||
- Add ECMP support for `unsafe_routes`. (#1332)
|
||||
|
||||
### Changed
|
||||
|
||||
- `default_local_cidr_any` now defaults to false, meaning that any firewall rule
|
||||
intended to target an `unsafe_routes` entry must explicitly declare it via the
|
||||
`local_cidr` field. This is almost always the intended behavior. This flag is
|
||||
deprecated and will be removed in a future release. (#1373)
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fix moving a udp address from one vpn address to another in the `static_host_map`
|
||||
which could cause rapid re-handshaking with an incorrect remote. (#1259)
|
||||
- Improve smoke tests in environments where the docker network is not the default. (#1347)
|
||||
|
||||
## [1.9.7] - 2025-10-10
|
||||
|
||||
### Security
|
||||
|
||||
- Fix an issue where Nebula could incorrectly accept and process a packet from an erroneous source IP when the sender's
|
||||
certificate is configured with unsafe_routes (cert v1/v2) or multiple IPs (cert v2). (#1494)
|
||||
|
||||
### Changed
|
||||
|
||||
- Disable sending `recv_error` messages when a packet is received outside the allowable counter window. (#1459)
|
||||
- Improve error messages and remove some unnecessary fatal conditions in the Windows and generic udp listener. (#1543)
|
||||
|
||||
## [1.9.6] - 2025-7-15
|
||||
|
||||
### Added
|
||||
|
||||
- Support dropping inactive tunnels. This is disabled by default in this release but can be enabled with `tunnels.drop_inactive`. See example config for more details. (#1413)
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fix Darwin freeze due to presence of some Network Extensions (#1426)
|
||||
- Ensure the same relay tunnel is always used when multiple relay tunnels are present (#1422)
|
||||
- Fix Windows freeze due to ICMP error handling (#1412)
|
||||
- Fix relay migration panic (#1403)
|
||||
|
||||
## [1.9.5] - 2024-12-05
|
||||
|
||||
### Added
|
||||
|
||||
- Gracefully ignore v2 certificates. (#1282)
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fix relays that refuse to re-establish after one of the remote tunnel pairs breaks. (#1277)
|
||||
deprecated and will be removed in a future release.
|
||||
|
||||
## [1.9.4] - 2024-09-09
|
||||
|
||||
@@ -723,11 +671,7 @@ created.)
|
||||
|
||||
- Initial public release.
|
||||
|
||||
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.0...HEAD
|
||||
[1.10.0]: https://github.com/slackhq/nebula/releases/tag/v1.10.0
|
||||
[1.9.7]: https://github.com/slackhq/nebula/releases/tag/v1.9.7
|
||||
[1.9.6]: https://github.com/slackhq/nebula/releases/tag/v1.9.6
|
||||
[1.9.5]: https://github.com/slackhq/nebula/releases/tag/v1.9.5
|
||||
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.4...HEAD
|
||||
[1.9.4]: https://github.com/slackhq/nebula/releases/tag/v1.9.4
|
||||
[1.9.3]: https://github.com/slackhq/nebula/releases/tag/v1.9.3
|
||||
[1.9.2]: https://github.com/slackhq/nebula/releases/tag/v1.9.2
|
||||
|
||||
55
README.md
55
README.md
@@ -4,7 +4,7 @@ It lets you seamlessly connect computers anywhere in the world. Nebula is portab
|
||||
It can be used to connect a small number of computers, but is also able to connect tens of thousands of computers.
|
||||
|
||||
Nebula incorporates a number of existing concepts like encryption, security groups, certificates,
|
||||
and tunneling.
|
||||
and tunneling, and each of those individual pieces existed before Nebula in various forms.
|
||||
What makes Nebula different to existing offerings is that it brings all of these ideas together,
|
||||
resulting in a sum that is greater than its individual parts.
|
||||
|
||||
@@ -12,7 +12,7 @@ Further documentation can be found [here](https://nebula.defined.net/docs/).
|
||||
|
||||
You can read more about Nebula [here](https://medium.com/p/884110a5579).
|
||||
|
||||
You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-39pk4xopc-CUKlGcb5Z39dQ0cK1v7ehA).
|
||||
You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-2xqe6e7vn-k_KGi8s13nsr7cvHVvHvuQ).
|
||||
|
||||
## Supported Platforms
|
||||
|
||||
@@ -28,33 +28,33 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for
|
||||
#### Distribution Packages
|
||||
|
||||
- [Arch Linux](https://archlinux.org/packages/extra/x86_64/nebula/)
|
||||
```sh
|
||||
sudo pacman -S nebula
|
||||
```
|
||||
$ sudo pacman -S nebula
|
||||
```
|
||||
|
||||
- [Fedora Linux](https://src.fedoraproject.org/rpms/nebula)
|
||||
```sh
|
||||
sudo dnf install nebula
|
||||
```
|
||||
$ sudo dnf install nebula
|
||||
```
|
||||
|
||||
- [Debian Linux](https://packages.debian.org/source/stable/nebula)
|
||||
```sh
|
||||
sudo apt install nebula
|
||||
```
|
||||
$ sudo apt install nebula
|
||||
```
|
||||
|
||||
- [Alpine Linux](https://pkgs.alpinelinux.org/packages?name=nebula)
|
||||
```sh
|
||||
sudo apk add nebula
|
||||
```
|
||||
$ sudo apk add nebula
|
||||
```
|
||||
|
||||
- [macOS Homebrew](https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/n/nebula.rb)
|
||||
```sh
|
||||
brew install nebula
|
||||
```
|
||||
$ brew install nebula
|
||||
```
|
||||
|
||||
- [Docker](https://hub.docker.com/r/nebulaoss/nebula)
|
||||
```sh
|
||||
docker pull nebulaoss/nebula
|
||||
```
|
||||
$ docker pull nebulaoss/nebula
|
||||
```
|
||||
|
||||
#### Mobile
|
||||
@@ -64,10 +64,10 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for
|
||||
|
||||
## Technical Overview
|
||||
|
||||
Nebula is a mutually authenticated peer-to-peer software-defined network based on the [Noise Protocol Framework](https://noiseprotocol.org/).
|
||||
Nebula is a mutually authenticated peer-to-peer software defined network based on the [Noise Protocol Framework](https://noiseprotocol.org/).
|
||||
Nebula uses certificates to assert a node's IP address, name, and membership within user-defined groups.
|
||||
Nebula's user-defined groups allow for provider agnostic traffic filtering between nodes.
|
||||
Discovery nodes (aka lighthouses) allow individual peers to find each other and optionally use UDP hole punching to establish connections from behind most firewalls or NATs.
|
||||
Discovery nodes allow individual peers to find each other and optionally use UDP hole punching to establish connections from behind most firewalls or NATs.
|
||||
Users can move data between nodes in any number of cloud service providers, datacenters, and endpoints, without needing to maintain a particular addressing scheme.
|
||||
|
||||
Nebula uses Elliptic-curve Diffie-Hellman (`ECDH`) key exchange and `AES-256-GCM` in its default configuration.
|
||||
@@ -82,34 +82,28 @@ To set up a Nebula network, you'll need:
|
||||
|
||||
#### 2. (Optional, but you really should..) At least one discovery node with a routable IP address, which we call a lighthouse.
|
||||
|
||||
Nebula lighthouses allow nodes to find each other, anywhere in the world. A lighthouse is the only node in a Nebula network whose IP should not change. Running a lighthouse requires very few compute resources, and you can easily use the least expensive option from a cloud hosting provider. If you're not sure which provider to use, a number of us have used $6/mo [DigitalOcean](https://digitalocean.com) droplets as lighthouses.
|
||||
Nebula lighthouses allow nodes to find each other, anywhere in the world. A lighthouse is the only node in a Nebula network whose IP should not change. Running a lighthouse requires very few compute resources, and you can easily use the least expensive option from a cloud hosting provider. If you're not sure which provider to use, a number of us have used $5/mo [DigitalOcean](https://digitalocean.com) droplets as lighthouses.
|
||||
|
||||
Once you have launched an instance, ensure that Nebula udp traffic (default port udp/4242) can reach it over the internet.
|
||||
|
||||
|
||||
#### 3. A Nebula certificate authority, which will be the root of trust for a particular Nebula network.
|
||||
|
||||
```sh
|
||||
```
|
||||
./nebula-cert ca -name "Myorganization, Inc"
|
||||
```
|
||||
|
||||
This will create files named `ca.key` and `ca.cert` in the current directory. The `ca.key` file is the most sensitive file you'll create, because it is the key used to sign the certificates for individual nebula nodes/hosts. Please store this file somewhere safe, preferably with strong encryption.
|
||||
|
||||
**Be aware!** By default, certificate authorities have a 1-year lifetime before expiration. See [this guide](https://nebula.defined.net/docs/guides/rotating-certificate-authority/) for details on rotating a CA.
|
||||
|
||||
#### 4. Nebula host keys and certificates generated from that certificate authority
|
||||
|
||||
This assumes you have four nodes, named lighthouse1, laptop, server1, host3. You can name the nodes any way you'd like, including FQDN. You'll also need to choose IP addresses and the associated subnet. In this example, we are creating a nebula network that will use 192.168.100.x/24 as its network range. This example also demonstrates nebula groups, which can later be used to define traffic rules in a nebula network.
|
||||
```sh
|
||||
```
|
||||
./nebula-cert sign -name "lighthouse1" -ip "192.168.100.1/24"
|
||||
./nebula-cert sign -name "laptop" -ip "192.168.100.2/24" -groups "laptop,home,ssh"
|
||||
./nebula-cert sign -name "server1" -ip "192.168.100.9/24" -groups "servers"
|
||||
./nebula-cert sign -name "host3" -ip "192.168.100.10/24"
|
||||
```
|
||||
|
||||
By default, host certificates will expire 1 second before the CA expires. Use the `-duration` flag to specify a shorter lifetime.
|
||||
|
||||
#### 5. Configuration files for each host
|
||||
|
||||
Download a copy of the nebula [example configuration](https://github.com/slackhq/nebula/blob/master/examples/config.yml).
|
||||
|
||||
* On the lighthouse node, you'll need to ensure `am_lighthouse: true` is set.
|
||||
@@ -124,13 +118,10 @@ For each host, copy the nebula binary to the host, along with `config.yml` from
|
||||
**DO NOT COPY `ca.key` TO INDIVIDUAL NODES.**
|
||||
|
||||
#### 7. Run nebula on each host
|
||||
|
||||
```sh
|
||||
```
|
||||
./nebula -config /path/to/config.yml
|
||||
```
|
||||
|
||||
For more detailed instructions, [find the full documentation here](https://nebula.defined.net/docs/).
|
||||
|
||||
## Building Nebula from source
|
||||
|
||||
Make sure you have [go](https://go.dev/doc/install) installed and clone this repo. Change to the nebula directory.
|
||||
@@ -149,10 +140,8 @@ The default curve used for cryptographic handshakes and signatures is Curve25519
|
||||
|
||||
In addition, Nebula can be built using the [BoringCrypto GOEXPERIMENT](https://github.com/golang/go/blob/go1.20/src/crypto/internal/boring/README.md) by running either of the following make targets:
|
||||
|
||||
```sh
|
||||
make bin-boringcrypto
|
||||
make release-boringcrypto
|
||||
```
|
||||
|
||||
This is not the recommended default deployment, but may be useful based on your compliance requirements.
|
||||
|
||||
@@ -160,3 +149,5 @@ This is not the recommended default deployment, but may be useful based on your
|
||||
|
||||
Nebula was created at Slack Technologies, Inc by Nate Brown and Ryan Huber, with contributions from Oliver Fross, Alan Lam, Wade Simmons, and Lining Wang.
|
||||
|
||||
|
||||
|
||||
|
||||
107
bits.go
107
bits.go
@@ -9,13 +9,14 @@ type Bits struct {
|
||||
length uint64
|
||||
current uint64
|
||||
bits []bool
|
||||
firstSeen bool
|
||||
lostCounter metrics.Counter
|
||||
dupeCounter metrics.Counter
|
||||
outOfWindowCounter metrics.Counter
|
||||
}
|
||||
|
||||
func NewBits(bits uint64) *Bits {
|
||||
b := &Bits{
|
||||
return &Bits{
|
||||
length: bits,
|
||||
bits: make([]bool, bits, bits),
|
||||
current: 0,
|
||||
@@ -23,37 +24,34 @@ func NewBits(bits uint64) *Bits {
|
||||
dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
|
||||
outOfWindowCounter: metrics.GetOrRegisterCounter("network.packets.out_of_window", nil),
|
||||
}
|
||||
|
||||
// There is no counter value 0, mark it to avoid counting a lost packet later.
|
||||
b.bits[0] = true
|
||||
b.current = 0
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
|
||||
func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
|
||||
// If i is the next number, return true.
|
||||
if i > b.current {
|
||||
if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) {
|
||||
return true
|
||||
}
|
||||
|
||||
// If i is within the window, check if it's been set already.
|
||||
if i > b.current-b.length || i < b.length && b.current < b.length {
|
||||
// If i is within the window, check if it's been set already. The first window will fail this check
|
||||
if i > b.current-b.length {
|
||||
return !b.bits[i%b.length]
|
||||
}
|
||||
|
||||
// If i is within the first window
|
||||
if i < b.length {
|
||||
return !b.bits[i%b.length]
|
||||
}
|
||||
|
||||
// Not within the window
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
||||
// If i is the next number, return true and update current.
|
||||
if i == b.current+1 {
|
||||
// Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter
|
||||
// The very first window can only be tracked as lost once we are on the 2nd window or greater
|
||||
if b.bits[i%b.length] == false && i > b.length {
|
||||
// Report missed packets, we can only understand what was missed after the first window has been gone through
|
||||
if i > b.length && b.bits[i%b.length] == false {
|
||||
b.lostCounter.Inc(1)
|
||||
}
|
||||
b.bits[i%b.length] = true
|
||||
@@ -61,32 +59,61 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// If i is a jump, adjust the window, record lost, update current, and return true
|
||||
if i > b.current {
|
||||
lost := int64(0)
|
||||
// Zero out the bits between the current and the new counter value, limited by the window size,
|
||||
// since the window is shifting
|
||||
for n := b.current + 1; n <= min(i, b.current+b.length); n++ {
|
||||
if b.bits[n%b.length] == false && n > b.length {
|
||||
lost++
|
||||
}
|
||||
// If i packet is greater than current but less than the maximum length of our bitmap,
|
||||
// flip everything in between to false and move ahead.
|
||||
if i > b.current && i < b.current+b.length {
|
||||
// In between current and i need to be zero'd to allow those packets to come in later
|
||||
for n := b.current + 1; n < i; n++ {
|
||||
b.bits[n%b.length] = false
|
||||
}
|
||||
|
||||
// Only record any skipped packets as a result of the window moving further than the window length
|
||||
// Any loss within the new window will be accounted for in future calls
|
||||
lost += max(0, int64(i-b.current-b.length))
|
||||
b.bits[i%b.length] = true
|
||||
b.current = i
|
||||
//l.Debugf("missed %d packets between %d and %d\n", i-b.current, i, b.current)
|
||||
return true
|
||||
}
|
||||
|
||||
// If i is greater than the delta between current and the total length of our bitmap,
|
||||
// just flip everything in the map and move ahead.
|
||||
if i >= b.current+b.length {
|
||||
// The current window loss will be accounted for later, only record the jump as loss up until then
|
||||
lost := maxInt64(0, int64(i-b.current-b.length))
|
||||
//TODO: explain this
|
||||
if b.current == 0 {
|
||||
lost++
|
||||
}
|
||||
|
||||
for n := range b.bits {
|
||||
// Don't want to count the first window as a loss
|
||||
//TODO: this is likely wrong, we are wanting to track only the bit slots that we aren't going to track anymore and this is marking everything as missed
|
||||
//if b.bits[n] == false {
|
||||
// lost++
|
||||
//}
|
||||
b.bits[n] = false
|
||||
}
|
||||
|
||||
b.lostCounter.Inc(lost)
|
||||
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("receiveWindow", m{"accepted": true, "currentCounter": b.current, "incomingCounter": i, "reason": "window shifting"}).
|
||||
Debug("Receive window")
|
||||
}
|
||||
b.bits[i%b.length] = true
|
||||
b.current = i
|
||||
return true
|
||||
}
|
||||
|
||||
// If i is within the current window but below the current counter,
|
||||
// Check to see if it's a duplicate
|
||||
if i > b.current-b.length || i < b.length && b.current < b.length {
|
||||
if b.current == i || b.bits[i%b.length] == true {
|
||||
// Allow for the 0 packet to come in within the first window
|
||||
if i == 0 && b.firstSeen == false && b.current < b.length {
|
||||
b.firstSeen = true
|
||||
b.bits[i%b.length] = true
|
||||
return true
|
||||
}
|
||||
|
||||
// If i is within the window of current minus length (the total pat window size),
|
||||
// allow it and flip to true but to NOT change current. We also have to account for the first window
|
||||
if ((b.current >= b.length && i > b.current-b.length) || (b.current < b.length && i < b.length)) && i <= b.current {
|
||||
if b.current == i {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}).
|
||||
Debug("Receive window")
|
||||
@@ -95,8 +122,18 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if b.bits[i%b.length] == true {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "old duplicate"}).
|
||||
Debug("Receive window")
|
||||
}
|
||||
b.dupeCounter.Inc(1)
|
||||
return false
|
||||
}
|
||||
|
||||
b.bits[i%b.length] = true
|
||||
return true
|
||||
|
||||
}
|
||||
|
||||
// In all other cases, fail and don't change current.
|
||||
@@ -110,3 +147,11 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func maxInt64(a, b int64) int64 {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
109
bits_test.go
109
bits_test.go
@@ -15,41 +15,48 @@ func TestBits(t *testing.T) {
|
||||
assert.Len(t, b.bits, 10)
|
||||
|
||||
// This is initialized to zero - receive one. This should work.
|
||||
|
||||
assert.True(t, b.Check(l, 1))
|
||||
assert.True(t, b.Update(l, 1))
|
||||
u := b.Update(l, 1)
|
||||
assert.True(t, u)
|
||||
assert.EqualValues(t, 1, b.current)
|
||||
g := []bool{true, true, false, false, false, false, false, false, false, false}
|
||||
g := []bool{false, true, false, false, false, false, false, false, false, false}
|
||||
assert.Equal(t, g, b.bits)
|
||||
|
||||
// Receive two
|
||||
assert.True(t, b.Check(l, 2))
|
||||
assert.True(t, b.Update(l, 2))
|
||||
u = b.Update(l, 2)
|
||||
assert.True(t, u)
|
||||
assert.EqualValues(t, 2, b.current)
|
||||
g = []bool{true, true, true, false, false, false, false, false, false, false}
|
||||
g = []bool{false, true, true, false, false, false, false, false, false, false}
|
||||
assert.Equal(t, g, b.bits)
|
||||
|
||||
// Receive two again - it will fail
|
||||
assert.False(t, b.Check(l, 2))
|
||||
assert.False(t, b.Update(l, 2))
|
||||
u = b.Update(l, 2)
|
||||
assert.False(t, u)
|
||||
assert.EqualValues(t, 2, b.current)
|
||||
|
||||
// Jump ahead to 15, which should clear everything and set the 6th element
|
||||
assert.True(t, b.Check(l, 15))
|
||||
assert.True(t, b.Update(l, 15))
|
||||
u = b.Update(l, 15)
|
||||
assert.True(t, u)
|
||||
assert.EqualValues(t, 15, b.current)
|
||||
g = []bool{false, false, false, false, false, true, false, false, false, false}
|
||||
assert.Equal(t, g, b.bits)
|
||||
|
||||
// Mark 14, which is allowed because it is in the window
|
||||
assert.True(t, b.Check(l, 14))
|
||||
assert.True(t, b.Update(l, 14))
|
||||
u = b.Update(l, 14)
|
||||
assert.True(t, u)
|
||||
assert.EqualValues(t, 15, b.current)
|
||||
g = []bool{false, false, false, false, true, true, false, false, false, false}
|
||||
assert.Equal(t, g, b.bits)
|
||||
|
||||
// Mark 5, which is not allowed because it is not in the window
|
||||
assert.False(t, b.Check(l, 5))
|
||||
assert.False(t, b.Update(l, 5))
|
||||
u = b.Update(l, 5)
|
||||
assert.False(t, u)
|
||||
assert.EqualValues(t, 15, b.current)
|
||||
g = []bool{false, false, false, false, true, true, false, false, false, false}
|
||||
assert.Equal(t, g, b.bits)
|
||||
@@ -62,29 +69,10 @@ func TestBits(t *testing.T) {
|
||||
|
||||
// Walk through a few windows in order
|
||||
b = NewBits(10)
|
||||
for i := uint64(1); i <= 100; i++ {
|
||||
for i := uint64(0); i <= 100; i++ {
|
||||
assert.True(t, b.Check(l, i), "Error while checking %v", i)
|
||||
assert.True(t, b.Update(l, i), "Error while updating %v", i)
|
||||
}
|
||||
|
||||
assert.False(t, b.Check(l, 1), "Out of window check")
|
||||
}
|
||||
|
||||
func TestBitsLargeJumps(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
b := NewBits(10)
|
||||
b.lostCounter.Clear()
|
||||
|
||||
b = NewBits(10)
|
||||
b.lostCounter.Clear()
|
||||
assert.True(t, b.Update(l, 55)) // We saw packet 55 and can still track 45,46,47,48,49,50,51,52,53,54
|
||||
assert.Equal(t, int64(45), b.lostCounter.Count())
|
||||
|
||||
assert.True(t, b.Update(l, 100)) // We saw packet 55 and 100 and can still track 90,91,92,93,94,95,96,97,98,99
|
||||
assert.Equal(t, int64(89), b.lostCounter.Count())
|
||||
|
||||
assert.True(t, b.Update(l, 200)) // We saw packet 55, 100, and 200 and can still track 190,191,192,193,194,195,196,197,198,199
|
||||
assert.Equal(t, int64(188), b.lostCounter.Count())
|
||||
}
|
||||
|
||||
func TestBitsDupeCounter(t *testing.T) {
|
||||
@@ -136,7 +124,8 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
|
||||
assert.False(t, b.Update(l, 0))
|
||||
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
|
||||
|
||||
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
|
||||
//tODO: make sure lostcounter doesn't increase in orderly increment
|
||||
assert.Equal(t, int64(20), b.lostCounter.Count())
|
||||
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
||||
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
|
||||
}
|
||||
@@ -148,6 +137,8 @@ func TestBitsLostCounter(t *testing.T) {
|
||||
b.dupeCounter.Clear()
|
||||
b.outOfWindowCounter.Clear()
|
||||
|
||||
//assert.True(t, b.Update(0))
|
||||
assert.True(t, b.Update(l, 0))
|
||||
assert.True(t, b.Update(l, 20))
|
||||
assert.True(t, b.Update(l, 21))
|
||||
assert.True(t, b.Update(l, 22))
|
||||
@@ -158,7 +149,7 @@ func TestBitsLostCounter(t *testing.T) {
|
||||
assert.True(t, b.Update(l, 27))
|
||||
assert.True(t, b.Update(l, 28))
|
||||
assert.True(t, b.Update(l, 29))
|
||||
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
|
||||
assert.Equal(t, int64(20), b.lostCounter.Count())
|
||||
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
||||
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
||||
|
||||
@@ -167,6 +158,8 @@ func TestBitsLostCounter(t *testing.T) {
|
||||
b.dupeCounter.Clear()
|
||||
b.outOfWindowCounter.Clear()
|
||||
|
||||
assert.True(t, b.Update(l, 0))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 9))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
// 10 will set 0 index, 0 was already set, no lost packets
|
||||
@@ -221,62 +214,6 @@ func TestBitsLostCounter(t *testing.T) {
|
||||
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
||||
}
|
||||
|
||||
func TestBitsLostCounterIssue1(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
b := NewBits(10)
|
||||
b.lostCounter.Clear()
|
||||
b.dupeCounter.Clear()
|
||||
b.outOfWindowCounter.Clear()
|
||||
|
||||
assert.True(t, b.Update(l, 4))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 1))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 9))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 2))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 3))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 5))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 6))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 7))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
// assert.True(t, b.Update(l, 8))
|
||||
assert.True(t, b.Update(l, 10))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 11))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
|
||||
assert.True(t, b.Update(l, 14))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
// Issue seems to be here, we reset missing packet 8 to false here and don't increment the lost counter
|
||||
assert.True(t, b.Update(l, 19))
|
||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 12))
|
||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 13))
|
||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 15))
|
||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 16))
|
||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 17))
|
||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 18))
|
||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 20))
|
||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(l, 21))
|
||||
|
||||
// We missed packet 8 above
|
||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
||||
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
||||
}
|
||||
|
||||
func BenchmarkBits(b *testing.B) {
|
||||
z := NewBits(10)
|
||||
for n := 0; n < b.N; n++ {
|
||||
|
||||
@@ -84,11 +84,16 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
|
||||
|
||||
calculatedRemotes := new(bart.Table[[]*calculatedRemote])
|
||||
|
||||
rawMap, ok := value.(map[string]any)
|
||||
rawMap, ok := value.(map[any]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
|
||||
}
|
||||
for rawCIDR, rawValue := range rawMap {
|
||||
for rawKey, rawValue := range rawMap {
|
||||
rawCIDR, ok := rawKey.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
|
||||
}
|
||||
|
||||
cidr, err := netip.ParsePrefix(rawCIDR)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
|
||||
@@ -124,7 +129,7 @@ func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculat
|
||||
}
|
||||
|
||||
func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
|
||||
rawMap, ok := raw.(map[string]any)
|
||||
rawMap, ok := raw.(map[any]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid type: %T", raw)
|
||||
}
|
||||
|
||||
@@ -58,9 +58,6 @@ type Certificate interface {
|
||||
// PublicKey is the raw bytes to be used in asymmetric cryptographic operations.
|
||||
PublicKey() []byte
|
||||
|
||||
// MarshalPublicKeyPEM is the value of PublicKey marshalled to PEM
|
||||
MarshalPublicKeyPEM() []byte
|
||||
|
||||
// Curve identifies which curve was used for the PublicKey and Signature.
|
||||
Curve() Curve
|
||||
|
||||
@@ -138,7 +135,8 @@ func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certific
|
||||
case Version2:
|
||||
c, err = unmarshalCertificateV2(rawCertBytes, publicKey, curve)
|
||||
default:
|
||||
return nil, ErrUnknownVersion
|
||||
//TODO: CERT-V2 make a static var
|
||||
return nil, fmt.Errorf("unknown certificate version %d", v)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
||||
@@ -83,10 +83,6 @@ func (c *certificateV1) PublicKey() []byte {
|
||||
return c.details.publicKey
|
||||
}
|
||||
|
||||
func (c *certificateV1) MarshalPublicKeyPEM() []byte {
|
||||
return marshalCertPublicKeyToPEM(c)
|
||||
}
|
||||
|
||||
func (c *certificateV1) Signature() []byte {
|
||||
return c.signature
|
||||
}
|
||||
@@ -114,10 +110,8 @@ func (c *certificateV1) CheckSignature(key []byte) bool {
|
||||
case Curve_CURVE25519:
|
||||
return ed25519.Verify(key, b, c.signature)
|
||||
case Curve_P256:
|
||||
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
x, y := elliptic.Unmarshal(elliptic.P256(), key)
|
||||
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
|
||||
hashed := sha256.Sum256(b)
|
||||
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
|
||||
default:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package cert
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
@@ -14,7 +13,6 @@ import (
|
||||
)
|
||||
|
||||
func TestCertificateV1_Marshal(t *testing.T) {
|
||||
t.Parallel()
|
||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
pubKey := []byte("1234567890abcedfghij1234567890ab")
|
||||
@@ -62,58 +60,6 @@ func TestCertificateV1_Marshal(t *testing.T) {
|
||||
assert.Equal(t, nc.Groups(), nc2.Groups())
|
||||
}
|
||||
|
||||
func TestCertificateV1_PublicKeyPem(t *testing.T) {
|
||||
t.Parallel()
|
||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
|
||||
|
||||
nc := certificateV1{
|
||||
details: detailsV1{
|
||||
name: "testing",
|
||||
networks: []netip.Prefix{},
|
||||
unsafeNetworks: []netip.Prefix{},
|
||||
groups: []string{"test-group1", "test-group2", "test-group3"},
|
||||
notBefore: before,
|
||||
notAfter: after,
|
||||
publicKey: pubKey,
|
||||
isCA: false,
|
||||
issuer: "1234567890abcedfghij1234567890ab",
|
||||
},
|
||||
signature: []byte("1234567890abcedfghij1234567890ab"),
|
||||
}
|
||||
|
||||
assert.Equal(t, Version1, nc.Version())
|
||||
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
||||
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
|
||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
|
||||
assert.False(t, nc.IsCA())
|
||||
|
||||
nc.details.isCA = true
|
||||
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
||||
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
|
||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
|
||||
assert.True(t, nc.IsCA())
|
||||
|
||||
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||
AAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-----END NEBULA P256 PUBLIC KEY-----
|
||||
`)
|
||||
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
|
||||
require.NoError(t, err)
|
||||
nc.details.curve = Curve_P256
|
||||
nc.details.publicKey = pubP256Key
|
||||
assert.Equal(t, Curve_P256, nc.Curve())
|
||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
|
||||
assert.True(t, nc.IsCA())
|
||||
|
||||
nc.details.isCA = false
|
||||
assert.Equal(t, Curve_P256, nc.Curve())
|
||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
|
||||
assert.False(t, nc.IsCA())
|
||||
}
|
||||
|
||||
func TestCertificateV1_Expired(t *testing.T) {
|
||||
nc := certificateV1{
|
||||
details: detailsV1{
|
||||
|
||||
@@ -114,10 +114,6 @@ func (c *certificateV2) PublicKey() []byte {
|
||||
return c.publicKey
|
||||
}
|
||||
|
||||
func (c *certificateV2) MarshalPublicKeyPEM() []byte {
|
||||
return marshalCertPublicKeyToPEM(c)
|
||||
}
|
||||
|
||||
func (c *certificateV2) Signature() []byte {
|
||||
return c.signature
|
||||
}
|
||||
@@ -153,10 +149,8 @@ func (c *certificateV2) CheckSignature(key []byte) bool {
|
||||
case Curve_CURVE25519:
|
||||
return ed25519.Verify(key, b, c.signature)
|
||||
case Curve_P256:
|
||||
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
x, y := elliptic.Unmarshal(elliptic.P256(), key)
|
||||
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
|
||||
hashed := sha256.Sum256(b)
|
||||
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
|
||||
default:
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
)
|
||||
|
||||
func TestCertificateV2_Marshal(t *testing.T) {
|
||||
t.Parallel()
|
||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
pubKey := []byte("1234567890abcedfghij1234567890ab")
|
||||
@@ -76,58 +75,6 @@ func TestCertificateV2_Marshal(t *testing.T) {
|
||||
assert.Equal(t, nc.Groups(), nc2.Groups())
|
||||
}
|
||||
|
||||
func TestCertificateV2_PublicKeyPem(t *testing.T) {
|
||||
t.Parallel()
|
||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
|
||||
|
||||
nc := certificateV2{
|
||||
details: detailsV2{
|
||||
name: "testing",
|
||||
networks: []netip.Prefix{},
|
||||
unsafeNetworks: []netip.Prefix{},
|
||||
groups: []string{"test-group1", "test-group2", "test-group3"},
|
||||
notBefore: before,
|
||||
notAfter: after,
|
||||
isCA: false,
|
||||
issuer: "1234567890abcedfghij1234567890ab",
|
||||
},
|
||||
publicKey: pubKey,
|
||||
signature: []byte("1234567890abcedfghij1234567890ab"),
|
||||
}
|
||||
|
||||
assert.Equal(t, Version2, nc.Version())
|
||||
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
||||
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
|
||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
|
||||
assert.False(t, nc.IsCA())
|
||||
|
||||
nc.details.isCA = true
|
||||
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
||||
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
|
||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
|
||||
assert.True(t, nc.IsCA())
|
||||
|
||||
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||
AAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-----END NEBULA P256 PUBLIC KEY-----
|
||||
`)
|
||||
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
|
||||
require.NoError(t, err)
|
||||
nc.curve = Curve_P256
|
||||
nc.publicKey = pubP256Key
|
||||
assert.Equal(t, Curve_P256, nc.Curve())
|
||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
|
||||
assert.True(t, nc.IsCA())
|
||||
|
||||
nc.details.isCA = false
|
||||
assert.Equal(t, Curve_P256, nc.Curve())
|
||||
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
|
||||
assert.False(t, nc.IsCA())
|
||||
}
|
||||
|
||||
func TestCertificateV2_Expired(t *testing.T) {
|
||||
nc := certificateV2{
|
||||
details: detailsV2{
|
||||
|
||||
@@ -26,21 +26,21 @@ func TestNewArgon2Parameters(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDecryptAndUnmarshalSigningPrivateKey(t *testing.T) {
|
||||
passphrase := []byte("DO NOT USE")
|
||||
passphrase := []byte("DO NOT USE THIS KEY")
|
||||
privKey := []byte(`# A good key
|
||||
-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
|
||||
CjsKC0FFUy0yNTYtR0NNEiwIExCAgAQYAyAEKiCPoDfGQiosxNPTbPn5EsMlc2MI
|
||||
c0Bt4oz6gTrFQhX3aBJcimhHKeAuhyTGvllD0Z19fe+DFPcLH3h5VrdjVfIAajg0
|
||||
KrbV3n9UHif/Au5skWmquNJzoW1E4MTdRbvpti6o+WdQ49DxjBFhx0YH8LBqrbPU
|
||||
0BGkUHmIO7daP24=
|
||||
CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT
|
||||
oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl
|
||||
+Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB
|
||||
qrlJ69wer3ZUHFXA
|
||||
-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
|
||||
`)
|
||||
shortKey := []byte(`# A key which, once decrypted, is too short
|
||||
-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
|
||||
CjsKC0FFUy0yNTYtR0NNEiwIExCAgAQYAyAEKiAVJwdfl3r+eqi/vF6S7OMdpjfo
|
||||
hAzmTCRnr58Su4AqmBJbCv3zleYCEKYJP6UI3S8ekLMGISsgO4hm5leukCCyqT0Z
|
||||
cQ76yrberpzkJKoPLGisX8f+xdy4aXSZl7oEYWQte1+vqbtl/eY9PGZhxUQdcyq7
|
||||
hqzIyrRqfUgVuA==
|
||||
CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCoga5h8owMEBWRSMMJKzuUvWce7
|
||||
k0qlBkQmCxiuLh80MuASW70YcKt8jeEIS2axo2V6zAKA9TSMcCsJW1kDDXEtL/xe
|
||||
GLF5T7sDl5COp4LU3pGxpV+KoeQ/S3gQCAAcnaOtnJQX+aSDnbO3jCHyP7U9CHbs
|
||||
rQr3bdH3Oy/WiYU=
|
||||
-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
|
||||
`)
|
||||
invalidBanner := []byte(`# Invalid banner (not encrypted)
|
||||
|
||||
@@ -20,7 +20,6 @@ var (
|
||||
ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair")
|
||||
ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
|
||||
ErrCaNotFound = errors.New("could not find ca for the certificate")
|
||||
ErrUnknownVersion = errors.New("certificate version unrecognized")
|
||||
|
||||
ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block")
|
||||
ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner")
|
||||
|
||||
44
cert/pem.go
44
cert/pem.go
@@ -7,26 +7,19 @@ import (
|
||||
"golang.org/x/crypto/ed25519"
|
||||
)
|
||||
|
||||
const ( //cert banners
|
||||
const (
|
||||
CertificateBanner = "NEBULA CERTIFICATE"
|
||||
CertificateV2Banner = "NEBULA CERTIFICATE V2"
|
||||
)
|
||||
|
||||
const ( //key-agreement-key banners
|
||||
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
|
||||
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
|
||||
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
|
||||
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
|
||||
)
|
||||
|
||||
/* including "ECDSA" in the P256 banners is a clue that these keys should be used only for signing */
|
||||
const ( //signing key banners
|
||||
EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
|
||||
ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
|
||||
ECDSAP256PublicKeyBanner = "NEBULA ECDSA P256 PUBLIC KEY"
|
||||
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
|
||||
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
|
||||
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
|
||||
|
||||
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
|
||||
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
|
||||
EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
|
||||
ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
|
||||
)
|
||||
|
||||
// UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed
|
||||
@@ -58,16 +51,6 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
|
||||
|
||||
}
|
||||
|
||||
func marshalCertPublicKeyToPEM(c Certificate) []byte {
|
||||
if c.IsCA() {
|
||||
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())
|
||||
} else {
|
||||
return MarshalPublicKeyToPEM(c.Curve(), c.PublicKey())
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalPublicKeyToPEM returns a PEM representation of a public key used for ECDH.
|
||||
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
|
||||
func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
|
||||
switch curve {
|
||||
case Curve_CURVE25519:
|
||||
@@ -79,19 +62,6 @@ func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalSigningPublicKeyToPEM returns a PEM representation of a public key used for signing.
|
||||
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
|
||||
func MarshalSigningPublicKeyToPEM(curve Curve, b []byte) []byte {
|
||||
switch curve {
|
||||
case Curve_CURVE25519:
|
||||
return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: b})
|
||||
case Curve_P256:
|
||||
return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
||||
k, r := pem.Decode(b)
|
||||
if k == nil {
|
||||
@@ -103,7 +73,7 @@ func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
||||
case X25519PublicKeyBanner, Ed25519PublicKeyBanner:
|
||||
expectedLen = 32
|
||||
curve = Curve_CURVE25519
|
||||
case P256PublicKeyBanner, ECDSAP256PublicKeyBanner:
|
||||
case P256PublicKeyBanner:
|
||||
// Uncompressed
|
||||
expectedLen = 65
|
||||
curve = Curve_P256
|
||||
|
||||
@@ -177,7 +177,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
}
|
||||
|
||||
func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
|
||||
t.Parallel()
|
||||
pubKey := []byte(`# A good key
|
||||
-----BEGIN NEBULA ED25519 PUBLIC KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
@@ -231,7 +230,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
}
|
||||
|
||||
func TestUnmarshalX25519PublicKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
pubKey := []byte(`# A good key
|
||||
-----BEGIN NEBULA X25519 PUBLIC KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
@@ -242,12 +240,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||
AAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-----END NEBULA P256 PUBLIC KEY-----
|
||||
`)
|
||||
oldPubP256Key := []byte(`# A good key
|
||||
-----BEGIN NEBULA ECDSA P256 PUBLIC KEY-----
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||
AAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-----END NEBULA ECDSA P256 PUBLIC KEY-----
|
||||
`)
|
||||
shortKey := []byte(`# A short key
|
||||
-----BEGIN NEBULA X25519 PUBLIC KEY-----
|
||||
@@ -264,22 +256,15 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
-END NEBULA X25519 PUBLIC KEY-----`)
|
||||
|
||||
keyBundle := appendByteSlices(pubKey, pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem)
|
||||
keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem)
|
||||
|
||||
// Success test case
|
||||
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
|
||||
assert.Len(t, k, 32)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, rest, appendByteSlices(pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem))
|
||||
assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
|
||||
assert.Equal(t, Curve_CURVE25519, curve)
|
||||
|
||||
// Success test case
|
||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||
assert.Len(t, k, 65)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, rest, appendByteSlices(oldPubP256Key, shortKey, invalidBanner, invalidPem))
|
||||
assert.Equal(t, Curve_P256, curve)
|
||||
|
||||
// Success test case
|
||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||
assert.Len(t, k, 65)
|
||||
|
||||
12
cert/sign.go
12
cert/sign.go
@@ -7,6 +7,7 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/netip"
|
||||
"time"
|
||||
)
|
||||
@@ -54,10 +55,15 @@ func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Cert
|
||||
}
|
||||
return t.SignWith(signer, curve, sp)
|
||||
case Curve_P256:
|
||||
pk, err := ecdsa.ParseRawPrivateKey(elliptic.P256(), key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
pk := &ecdsa.PrivateKey{
|
||||
PublicKey: ecdsa.PublicKey{
|
||||
Curve: elliptic.P256(),
|
||||
},
|
||||
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
|
||||
D: new(big.Int).SetBytes(key),
|
||||
}
|
||||
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
|
||||
pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
|
||||
sp := func(certBytes []byte) ([]byte, error) {
|
||||
// We need to hash first for ECDSA
|
||||
// - https://pkg.go.dev/crypto/ecdsa#SignASN1
|
||||
|
||||
@@ -114,33 +114,6 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by
|
||||
return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
|
||||
}
|
||||
|
||||
func NewTestCertDifferentVersion(c cert.Certificate, v cert.Version, ca cert.Certificate, key []byte) (cert.Certificate, []byte) {
|
||||
nc := &cert.TBSCertificate{
|
||||
Version: v,
|
||||
Curve: c.Curve(),
|
||||
Name: c.Name(),
|
||||
Networks: c.Networks(),
|
||||
UnsafeNetworks: c.UnsafeNetworks(),
|
||||
Groups: c.Groups(),
|
||||
NotBefore: time.Unix(c.NotBefore().Unix(), 0),
|
||||
NotAfter: time.Unix(c.NotAfter().Unix(), 0),
|
||||
PublicKey: c.PublicKey(),
|
||||
IsCA: false,
|
||||
}
|
||||
|
||||
c, err := nc.Sign(ca, ca.Curve(), key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
pem, err := c.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return c, pem
|
||||
}
|
||||
|
||||
func X25519Keypair() ([]byte, []byte) {
|
||||
privkey := make([]byte, 32)
|
||||
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
|
||||
|
||||
@@ -173,8 +173,6 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
|
||||
|
||||
var passphrase []byte
|
||||
if !isP11 && *cf.encryption {
|
||||
passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
|
||||
if len(passphrase) == 0 {
|
||||
for i := 0; i < 5; i++ {
|
||||
out.Write([]byte("Enter passphrase: "))
|
||||
passphrase, err = pr.ReadPassword()
|
||||
@@ -194,7 +192,6 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
|
||||
return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var curve cert.Curve
|
||||
var pub, rawPriv []byte
|
||||
|
||||
@@ -171,17 +171,6 @@ func Test_ca(t *testing.T) {
|
||||
assert.Equal(t, pwPromptOb, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// test encrypted key with passphrase environment variable
|
||||
os.Remove(keyF.Name())
|
||||
os.Remove(crtF.Name())
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
|
||||
require.NoError(t, ca(args, ob, eb, testpw))
|
||||
assert.Empty(t, eb.String())
|
||||
os.Setenv("NEBULA_CA_PASSPHRASE", "")
|
||||
|
||||
// read encrypted key file and verify default params
|
||||
rb, _ = os.ReadFile(keyF.Name())
|
||||
k, _ := pem.Decode(rb)
|
||||
|
||||
@@ -5,28 +5,10 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// A version string that can be set with
|
||||
//
|
||||
// -ldflags "-X main.Build=SOMEVERSION"
|
||||
//
|
||||
// at compile-time.
|
||||
var Build string
|
||||
|
||||
func init() {
|
||||
if Build == "" {
|
||||
info, ok := debug.ReadBuildInfo()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
Build = strings.TrimPrefix(info.Main.Version, "v")
|
||||
}
|
||||
}
|
||||
|
||||
type helpError struct {
|
||||
s string
|
||||
}
|
||||
|
||||
@@ -116,10 +116,8 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
||||
// naively attempt to decode the private key as though it is not encrypted
|
||||
caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
|
||||
if errors.Is(err, cert.ErrPrivateKeyEncrypted) {
|
||||
var passphrase []byte
|
||||
passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
|
||||
if len(passphrase) == 0 {
|
||||
// ask for a passphrase until we get one
|
||||
var passphrase []byte
|
||||
for i := 0; i < 5; i++ {
|
||||
out.Write([]byte("Enter passphrase: "))
|
||||
passphrase, err = pr.ReadPassword()
|
||||
@@ -137,7 +135,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
||||
if len(passphrase) == 0 {
|
||||
return fmt.Errorf("cannot open encrypted ca-key without passphrase")
|
||||
}
|
||||
}
|
||||
|
||||
curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while parsing encrypted ca-key: %s", err)
|
||||
|
||||
@@ -379,15 +379,6 @@ func Test_signCert(t *testing.T) {
|
||||
assert.Equal(t, "Enter passphrase: ", ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// test with the proper password in the environment
|
||||
os.Remove(crtF.Name())
|
||||
os.Remove(keyF.Name())
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
|
||||
require.NoError(t, signCert(args, ob, eb, testpw))
|
||||
assert.Empty(t, eb.String())
|
||||
os.Setenv("NEBULA_CA_PASSPHRASE", "")
|
||||
|
||||
// test with the wrong password
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
@@ -398,17 +389,6 @@ func Test_signCert(t *testing.T) {
|
||||
assert.Equal(t, "Enter passphrase: ", ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// test with the wrong password in environment
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
|
||||
os.Setenv("NEBULA_CA_PASSPHRASE", "invalid password")
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing encrypted ca-key: invalid passphrase or corrupt private key")
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
os.Setenv("NEBULA_CA_PASSPHRASE", "")
|
||||
|
||||
// test with the user not entering a password
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
|
||||
@@ -4,8 +4,6 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula"
|
||||
@@ -20,17 +18,6 @@ import (
|
||||
// at compile-time.
|
||||
var Build string
|
||||
|
||||
func init() {
|
||||
if Build == "" {
|
||||
info, ok := debug.ReadBuildInfo()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
Build = strings.TrimPrefix(info.Main.Version, "v")
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
serviceFlag := flag.String("service", "", "Control the system service.")
|
||||
configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")
|
||||
|
||||
@@ -4,8 +4,6 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula"
|
||||
@@ -20,17 +18,6 @@ import (
|
||||
// at compile-time.
|
||||
var Build string
|
||||
|
||||
func init() {
|
||||
if Build == "" {
|
||||
info, ok := debug.ReadBuildInfo()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
Build = strings.TrimPrefix(info.Main.Version, "v")
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")
|
||||
configTest := flag.Bool("test", false, "Test the config and print the end result. Non zero exit indicates a faulty config")
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
|
||||
"dario.cat/mergo"
|
||||
"github.com/sirupsen/logrus"
|
||||
"go.yaml.in/yaml/v3"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type C struct {
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.yaml.in/yaml/v3"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestConfig_Load(t *testing.T) {
|
||||
|
||||
@@ -4,16 +4,13 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/header"
|
||||
)
|
||||
|
||||
@@ -30,6 +27,12 @@ const (
|
||||
)
|
||||
|
||||
type connectionManager struct {
|
||||
in map[uint32]struct{}
|
||||
inLock *sync.RWMutex
|
||||
|
||||
out map[uint32]struct{}
|
||||
outLock *sync.RWMutex
|
||||
|
||||
// relayUsed holds which relay localIndexs are in use
|
||||
relayUsed map[uint32]struct{}
|
||||
relayUsedLock *sync.RWMutex
|
||||
@@ -37,117 +40,117 @@ type connectionManager struct {
|
||||
hostMap *HostMap
|
||||
trafficTimer *LockingTimerWheel[uint32]
|
||||
intf *Interface
|
||||
pendingDeletion map[uint32]struct{}
|
||||
punchy *Punchy
|
||||
|
||||
// Configuration settings
|
||||
checkInterval time.Duration
|
||||
pendingDeletionInterval time.Duration
|
||||
inactivityTimeout atomic.Int64
|
||||
dropInactive atomic.Bool
|
||||
|
||||
metricsTxPunchy metrics.Counter
|
||||
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
|
||||
cm := &connectionManager{
|
||||
hostMap: hm,
|
||||
l: l,
|
||||
punchy: p,
|
||||
func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval time.Duration, punchy *Punchy) *connectionManager {
|
||||
var max time.Duration
|
||||
if checkInterval < pendingDeletionInterval {
|
||||
max = pendingDeletionInterval
|
||||
} else {
|
||||
max = checkInterval
|
||||
}
|
||||
|
||||
nc := &connectionManager{
|
||||
hostMap: intf.hostMap,
|
||||
in: make(map[uint32]struct{}),
|
||||
inLock: &sync.RWMutex{},
|
||||
out: make(map[uint32]struct{}),
|
||||
outLock: &sync.RWMutex{},
|
||||
relayUsed: make(map[uint32]struct{}),
|
||||
relayUsedLock: &sync.RWMutex{},
|
||||
trafficTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, max),
|
||||
intf: intf,
|
||||
pendingDeletion: make(map[uint32]struct{}),
|
||||
checkInterval: checkInterval,
|
||||
pendingDeletionInterval: pendingDeletionInterval,
|
||||
punchy: punchy,
|
||||
metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
|
||||
l: l,
|
||||
}
|
||||
|
||||
cm.reload(c, true)
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
cm.reload(c, false)
|
||||
})
|
||||
|
||||
return cm
|
||||
nc.Start(ctx)
|
||||
return nc
|
||||
}
|
||||
|
||||
func (cm *connectionManager) reload(c *config.C, initial bool) {
|
||||
if initial {
|
||||
cm.checkInterval = time.Duration(c.GetInt("timers.connection_alive_interval", 5)) * time.Second
|
||||
cm.pendingDeletionInterval = time.Duration(c.GetInt("timers.pending_deletion_interval", 10)) * time.Second
|
||||
|
||||
// We want at least a minimum resolution of 500ms per tick so that we can hit these intervals
|
||||
// pretty close to their configured duration.
|
||||
// The inactivity duration is checked each time a hostinfo ticks through so we don't need the wheel to contain it.
|
||||
minDuration := min(time.Millisecond*500, cm.checkInterval, cm.pendingDeletionInterval)
|
||||
maxDuration := max(cm.checkInterval, cm.pendingDeletionInterval)
|
||||
cm.trafficTimer = NewLockingTimerWheel[uint32](minDuration, maxDuration)
|
||||
}
|
||||
|
||||
if initial || c.HasChanged("tunnels.inactivity_timeout") {
|
||||
old := cm.getInactivityTimeout()
|
||||
cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute)))
|
||||
if !initial {
|
||||
cm.l.WithField("oldDuration", old).
|
||||
WithField("newDuration", cm.getInactivityTimeout()).
|
||||
Info("Inactivity timeout has changed")
|
||||
}
|
||||
}
|
||||
|
||||
if initial || c.HasChanged("tunnels.drop_inactive") {
|
||||
old := cm.dropInactive.Load()
|
||||
cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false))
|
||||
if !initial {
|
||||
cm.l.WithField("oldBool", old).
|
||||
WithField("newBool", cm.dropInactive.Load()).
|
||||
Info("Drop inactive setting has changed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *connectionManager) getInactivityTimeout() time.Duration {
|
||||
return (time.Duration)(cm.inactivityTimeout.Load())
|
||||
}
|
||||
|
||||
func (cm *connectionManager) In(h *HostInfo) {
|
||||
h.in.Store(true)
|
||||
}
|
||||
|
||||
func (cm *connectionManager) Out(h *HostInfo) {
|
||||
h.out.Store(true)
|
||||
}
|
||||
|
||||
func (cm *connectionManager) RelayUsed(localIndex uint32) {
|
||||
cm.relayUsedLock.RLock()
|
||||
func (n *connectionManager) In(localIndex uint32) {
|
||||
n.inLock.RLock()
|
||||
// If this already exists, return
|
||||
if _, ok := cm.relayUsed[localIndex]; ok {
|
||||
cm.relayUsedLock.RUnlock()
|
||||
if _, ok := n.in[localIndex]; ok {
|
||||
n.inLock.RUnlock()
|
||||
return
|
||||
}
|
||||
cm.relayUsedLock.RUnlock()
|
||||
cm.relayUsedLock.Lock()
|
||||
cm.relayUsed[localIndex] = struct{}{}
|
||||
cm.relayUsedLock.Unlock()
|
||||
n.inLock.RUnlock()
|
||||
n.inLock.Lock()
|
||||
n.in[localIndex] = struct{}{}
|
||||
n.inLock.Unlock()
|
||||
}
|
||||
|
||||
func (n *connectionManager) Out(localIndex uint32) {
|
||||
n.outLock.RLock()
|
||||
// If this already exists, return
|
||||
if _, ok := n.out[localIndex]; ok {
|
||||
n.outLock.RUnlock()
|
||||
return
|
||||
}
|
||||
n.outLock.RUnlock()
|
||||
n.outLock.Lock()
|
||||
n.out[localIndex] = struct{}{}
|
||||
n.outLock.Unlock()
|
||||
}
|
||||
|
||||
func (n *connectionManager) RelayUsed(localIndex uint32) {
|
||||
n.relayUsedLock.RLock()
|
||||
// If this already exists, return
|
||||
if _, ok := n.relayUsed[localIndex]; ok {
|
||||
n.relayUsedLock.RUnlock()
|
||||
return
|
||||
}
|
||||
n.relayUsedLock.RUnlock()
|
||||
n.relayUsedLock.Lock()
|
||||
n.relayUsed[localIndex] = struct{}{}
|
||||
n.relayUsedLock.Unlock()
|
||||
}
|
||||
|
||||
// getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and
|
||||
// resets the state for this local index
|
||||
func (cm *connectionManager) getAndResetTrafficCheck(h *HostInfo, now time.Time) (bool, bool) {
|
||||
in := h.in.Swap(false)
|
||||
out := h.out.Swap(false)
|
||||
if in || out {
|
||||
h.lastUsed = now
|
||||
}
|
||||
func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) {
|
||||
n.inLock.Lock()
|
||||
n.outLock.Lock()
|
||||
_, in := n.in[localIndex]
|
||||
_, out := n.out[localIndex]
|
||||
delete(n.in, localIndex)
|
||||
delete(n.out, localIndex)
|
||||
n.inLock.Unlock()
|
||||
n.outLock.Unlock()
|
||||
return in, out
|
||||
}
|
||||
|
||||
// AddTrafficWatch must be called for every new HostInfo.
|
||||
// We will continue to monitor the HostInfo until the tunnel is dropped.
|
||||
func (cm *connectionManager) AddTrafficWatch(h *HostInfo) {
|
||||
if h.out.Swap(true) == false {
|
||||
cm.trafficTimer.Add(h.localIndexId, cm.checkInterval)
|
||||
func (n *connectionManager) AddTrafficWatch(localIndex uint32) {
|
||||
// Use a write lock directly because it should be incredibly rare that we are ever already tracking this index
|
||||
n.outLock.Lock()
|
||||
if _, ok := n.out[localIndex]; ok {
|
||||
n.outLock.Unlock()
|
||||
return
|
||||
}
|
||||
n.out[localIndex] = struct{}{}
|
||||
n.trafficTimer.Add(localIndex, n.checkInterval)
|
||||
n.outLock.Unlock()
|
||||
}
|
||||
|
||||
func (cm *connectionManager) Start(ctx context.Context) {
|
||||
clockSource := time.NewTicker(cm.trafficTimer.t.tickDuration)
|
||||
func (n *connectionManager) Start(ctx context.Context) {
|
||||
go n.Run(ctx)
|
||||
}
|
||||
|
||||
func (n *connectionManager) Run(ctx context.Context) {
|
||||
//TODO: this tick should be based on the min wheel tick? Check firewall
|
||||
clockSource := time.NewTicker(500 * time.Millisecond)
|
||||
defer clockSource.Stop()
|
||||
|
||||
p := []byte("")
|
||||
@@ -160,61 +163,61 @@ func (cm *connectionManager) Start(ctx context.Context) {
|
||||
return
|
||||
|
||||
case now := <-clockSource.C:
|
||||
cm.trafficTimer.Advance(now)
|
||||
n.trafficTimer.Advance(now)
|
||||
for {
|
||||
localIndex, has := cm.trafficTimer.Purge()
|
||||
localIndex, has := n.trafficTimer.Purge()
|
||||
if !has {
|
||||
break
|
||||
}
|
||||
|
||||
cm.doTrafficCheck(localIndex, p, nb, out, now)
|
||||
n.doTrafficCheck(localIndex, p, nb, out, now)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
|
||||
decision, hostinfo, primary := cm.makeTrafficDecision(localIndex, now)
|
||||
func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
|
||||
decision, hostinfo, primary := n.makeTrafficDecision(localIndex, now)
|
||||
|
||||
switch decision {
|
||||
case deleteTunnel:
|
||||
if cm.hostMap.DeleteHostInfo(hostinfo) {
|
||||
if n.hostMap.DeleteHostInfo(hostinfo) {
|
||||
// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
|
||||
cm.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
|
||||
n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
|
||||
}
|
||||
|
||||
case closeTunnel:
|
||||
cm.intf.sendCloseTunnel(hostinfo)
|
||||
cm.intf.closeTunnel(hostinfo)
|
||||
n.intf.sendCloseTunnel(hostinfo)
|
||||
n.intf.closeTunnel(hostinfo)
|
||||
|
||||
case swapPrimary:
|
||||
cm.swapPrimary(hostinfo, primary)
|
||||
n.swapPrimary(hostinfo, primary)
|
||||
|
||||
case migrateRelays:
|
||||
cm.migrateRelayUsed(hostinfo, primary)
|
||||
n.migrateRelayUsed(hostinfo, primary)
|
||||
|
||||
case tryRehandshake:
|
||||
cm.tryRehandshake(hostinfo)
|
||||
n.tryRehandshake(hostinfo)
|
||||
|
||||
case sendTestPacket:
|
||||
cm.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
|
||||
n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
|
||||
}
|
||||
|
||||
cm.resetRelayTrafficCheck(hostinfo)
|
||||
n.resetRelayTrafficCheck(hostinfo)
|
||||
}
|
||||
|
||||
func (cm *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
|
||||
func (n *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
|
||||
if hostinfo != nil {
|
||||
cm.relayUsedLock.Lock()
|
||||
defer cm.relayUsedLock.Unlock()
|
||||
n.relayUsedLock.Lock()
|
||||
defer n.relayUsedLock.Unlock()
|
||||
// No need to migrate any relays, delete usage info now.
|
||||
for _, idx := range hostinfo.relayState.CopyRelayForIdxs() {
|
||||
delete(cm.relayUsed, idx)
|
||||
delete(n.relayUsed, idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
|
||||
func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
|
||||
relayFor := oldhostinfo.relayState.CopyAllRelayFor()
|
||||
|
||||
for _, r := range relayFor {
|
||||
@@ -224,51 +227,46 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
|
||||
var relayFrom netip.Addr
|
||||
var relayTo netip.Addr
|
||||
switch {
|
||||
case ok:
|
||||
switch existing.State {
|
||||
case Established, PeerRequested, Disestablished:
|
||||
case ok && existing.State == Established:
|
||||
// This relay already exists in newhostinfo, then do nothing.
|
||||
continue
|
||||
case Requested:
|
||||
case ok && existing.State == Requested:
|
||||
// The relay exists in a Requested state; re-send the request
|
||||
index = existing.LocalIndex
|
||||
switch r.Type {
|
||||
case TerminalType:
|
||||
relayFrom = cm.intf.myVpnAddrs[0]
|
||||
relayFrom = n.intf.myVpnAddrs[0]
|
||||
relayTo = existing.PeerAddr
|
||||
case ForwardingType:
|
||||
relayFrom = existing.PeerAddr
|
||||
relayTo = newhostinfo.vpnAddrs[0]
|
||||
default:
|
||||
// should never happen
|
||||
panic(fmt.Sprintf("Migrating unknown relay type: %v", r.Type))
|
||||
}
|
||||
}
|
||||
case !ok:
|
||||
cm.relayUsedLock.RLock()
|
||||
if _, relayUsed := cm.relayUsed[r.LocalIndex]; !relayUsed {
|
||||
n.relayUsedLock.RLock()
|
||||
if _, relayUsed := n.relayUsed[r.LocalIndex]; !relayUsed {
|
||||
// The relay hasn't been used; don't migrate it.
|
||||
cm.relayUsedLock.RUnlock()
|
||||
n.relayUsedLock.RUnlock()
|
||||
continue
|
||||
}
|
||||
cm.relayUsedLock.RUnlock()
|
||||
n.relayUsedLock.RUnlock()
|
||||
// The relay doesn't exist at all; create some relay state and send the request.
|
||||
var err error
|
||||
index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested)
|
||||
index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested)
|
||||
if err != nil {
|
||||
cm.l.WithError(err).Error("failed to migrate relay to new hostinfo")
|
||||
n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
|
||||
continue
|
||||
}
|
||||
switch r.Type {
|
||||
case TerminalType:
|
||||
relayFrom = cm.intf.myVpnAddrs[0]
|
||||
relayFrom = n.intf.myVpnAddrs[0]
|
||||
relayTo = r.PeerAddr
|
||||
case ForwardingType:
|
||||
relayFrom = r.PeerAddr
|
||||
relayTo = newhostinfo.vpnAddrs[0]
|
||||
default:
|
||||
// should never happen
|
||||
panic(fmt.Sprintf("Migrating unknown relay type: %v", r.Type))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -281,12 +279,12 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
|
||||
switch newhostinfo.GetCert().Certificate.Version() {
|
||||
case cert.Version1:
|
||||
if !relayFrom.Is4() {
|
||||
cm.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
|
||||
n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
|
||||
continue
|
||||
}
|
||||
|
||||
if !relayTo.Is4() {
|
||||
cm.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
|
||||
n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -298,16 +296,16 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
|
||||
req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
|
||||
req.RelayToAddr = netAddrToProtoAddr(relayTo)
|
||||
default:
|
||||
newhostinfo.logger(cm.l).Error("Unknown certificate version found while attempting to migrate relay")
|
||||
newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay")
|
||||
continue
|
||||
}
|
||||
|
||||
msg, err := req.Marshal()
|
||||
if err != nil {
|
||||
cm.l.WithError(err).Error("failed to marshal Control message to migrate relay")
|
||||
n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
|
||||
} else {
|
||||
cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||
cm.l.WithFields(logrus.Fields{
|
||||
n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||
n.l.WithFields(logrus.Fields{
|
||||
"relayFrom": req.RelayFromAddr,
|
||||
"relayTo": req.RelayToAddr,
|
||||
"initiatorRelayIndex": req.InitiatorRelayIndex,
|
||||
@@ -318,44 +316,46 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
|
||||
// Read lock the main hostmap to order decisions based on tunnels being the primary tunnel
|
||||
cm.hostMap.RLock()
|
||||
defer cm.hostMap.RUnlock()
|
||||
func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
|
||||
n.hostMap.RLock()
|
||||
defer n.hostMap.RUnlock()
|
||||
|
||||
hostinfo := cm.hostMap.Indexes[localIndex]
|
||||
hostinfo := n.hostMap.Indexes[localIndex]
|
||||
if hostinfo == nil {
|
||||
cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap")
|
||||
n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
|
||||
delete(n.pendingDeletion, localIndex)
|
||||
return doNothing, nil, nil
|
||||
}
|
||||
|
||||
if cm.isInvalidCertificate(now, hostinfo) {
|
||||
if n.isInvalidCertificate(now, hostinfo) {
|
||||
delete(n.pendingDeletion, hostinfo.localIndexId)
|
||||
return closeTunnel, hostinfo, nil
|
||||
}
|
||||
|
||||
primary := cm.hostMap.Hosts[hostinfo.vpnAddrs[0]]
|
||||
primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]]
|
||||
mainHostInfo := true
|
||||
if primary != nil && primary != hostinfo {
|
||||
mainHostInfo = false
|
||||
}
|
||||
|
||||
// Check for traffic on this hostinfo
|
||||
inTraffic, outTraffic := cm.getAndResetTrafficCheck(hostinfo, now)
|
||||
inTraffic, outTraffic := n.getAndResetTrafficCheck(localIndex)
|
||||
|
||||
// A hostinfo is determined alive if there is incoming traffic
|
||||
if inTraffic {
|
||||
decision := doNothing
|
||||
if cm.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(cm.l).
|
||||
if n.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(n.l).
|
||||
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
|
||||
Debug("Tunnel status")
|
||||
}
|
||||
hostinfo.pendingDeletion.Store(false)
|
||||
delete(n.pendingDeletion, hostinfo.localIndexId)
|
||||
|
||||
if mainHostInfo {
|
||||
decision = tryRehandshake
|
||||
|
||||
} else {
|
||||
if cm.shouldSwapPrimary(hostinfo) {
|
||||
if n.shouldSwapPrimary(hostinfo, primary) {
|
||||
decision = swapPrimary
|
||||
} else {
|
||||
// migrate the relays to the primary, if in use.
|
||||
@@ -363,55 +363,46 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
||||
}
|
||||
}
|
||||
|
||||
cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
|
||||
n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
|
||||
|
||||
if !outTraffic {
|
||||
// Send a punch packet to keep the NAT state alive
|
||||
cm.sendPunch(hostinfo)
|
||||
n.sendPunch(hostinfo)
|
||||
}
|
||||
|
||||
return decision, hostinfo, primary
|
||||
}
|
||||
|
||||
if hostinfo.pendingDeletion.Load() {
|
||||
if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok {
|
||||
// We have already sent a test packet and nothing was returned, this hostinfo is dead
|
||||
hostinfo.logger(cm.l).
|
||||
hostinfo.logger(n.l).
|
||||
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
|
||||
Info("Tunnel status")
|
||||
|
||||
delete(n.pendingDeletion, hostinfo.localIndexId)
|
||||
return deleteTunnel, hostinfo, nil
|
||||
}
|
||||
|
||||
decision := doNothing
|
||||
if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
|
||||
if !outTraffic {
|
||||
inactiveFor, isInactive := cm.isInactive(hostinfo, now)
|
||||
if isInactive {
|
||||
// Tunnel is inactive, tear it down
|
||||
hostinfo.logger(cm.l).
|
||||
WithField("inactiveDuration", inactiveFor).
|
||||
WithField("primary", mainHostInfo).
|
||||
Info("Dropping tunnel due to inactivity")
|
||||
|
||||
return closeTunnel, hostinfo, primary
|
||||
}
|
||||
|
||||
// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
|
||||
// Just maintain NAT state if configured to do so.
|
||||
cm.sendPunch(hostinfo)
|
||||
cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
|
||||
n.sendPunch(hostinfo)
|
||||
n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
|
||||
return doNothing, nil, nil
|
||||
|
||||
}
|
||||
|
||||
if cm.punchy.GetTargetEverything() {
|
||||
if n.punchy.GetTargetEverything() {
|
||||
// This is similar to the old punchy behavior with a slight optimization.
|
||||
// We aren't receiving traffic but we are sending it, punch on all known
|
||||
// ips in case we need to re-prime NAT state
|
||||
cm.sendPunch(hostinfo)
|
||||
n.sendPunch(hostinfo)
|
||||
}
|
||||
|
||||
if cm.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(cm.l).
|
||||
if n.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(n.l).
|
||||
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
|
||||
Debug("Tunnel status")
|
||||
}
|
||||
@@ -420,33 +411,17 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
||||
decision = sendTestPacket
|
||||
|
||||
} else {
|
||||
if cm.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(cm.l).Debugf("Hostinfo sadness")
|
||||
if n.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(n.l).Debugf("Hostinfo sadness")
|
||||
}
|
||||
}
|
||||
|
||||
hostinfo.pendingDeletion.Store(true)
|
||||
cm.trafficTimer.Add(hostinfo.localIndexId, cm.pendingDeletionInterval)
|
||||
n.pendingDeletion[hostinfo.localIndexId] = struct{}{}
|
||||
n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval)
|
||||
return decision, hostinfo, nil
|
||||
}
|
||||
|
||||
func (cm *connectionManager) isInactive(hostinfo *HostInfo, now time.Time) (time.Duration, bool) {
|
||||
if cm.dropInactive.Load() == false {
|
||||
// We aren't configured to drop inactive tunnels
|
||||
return 0, false
|
||||
}
|
||||
|
||||
inactiveDuration := now.Sub(hostinfo.lastUsed)
|
||||
if inactiveDuration < cm.getInactivityTimeout() {
|
||||
// It's not considered inactive
|
||||
return inactiveDuration, false
|
||||
}
|
||||
|
||||
// The tunnel is inactive
|
||||
return inactiveDuration, true
|
||||
}
|
||||
|
||||
func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
|
||||
func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
|
||||
// The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
|
||||
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
|
||||
// Let's sort this out.
|
||||
@@ -454,127 +429,83 @@ func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
|
||||
// Only one side should swap because if both swap then we may never resolve to a single tunnel.
|
||||
// vpn addr is static across all tunnels for this host pair so lets
|
||||
// use that to determine if we should consider swapping.
|
||||
if current.vpnAddrs[0].Compare(cm.intf.myVpnAddrs[0]) < 0 {
|
||||
if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 {
|
||||
// Their primary vpn addr is less than mine. Do not swap.
|
||||
return false
|
||||
}
|
||||
|
||||
crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
|
||||
if crt == nil {
|
||||
//my cert was reloaded away. We should definitely swap from this tunnel
|
||||
return true
|
||||
}
|
||||
crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
|
||||
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
|
||||
// settle down.
|
||||
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
|
||||
}
|
||||
|
||||
func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
|
||||
cm.hostMap.Lock()
|
||||
func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
|
||||
n.hostMap.Lock()
|
||||
// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
|
||||
if cm.hostMap.Hosts[current.vpnAddrs[0]] == primary {
|
||||
cm.hostMap.unlockedMakePrimary(current)
|
||||
if n.hostMap.Hosts[current.vpnAddrs[0]] == primary {
|
||||
n.hostMap.unlockedMakePrimary(current)
|
||||
}
|
||||
cm.hostMap.Unlock()
|
||||
n.hostMap.Unlock()
|
||||
}
|
||||
|
||||
// isInvalidCertificate decides if we should destroy a tunnel.
|
||||
// returns true if pki.disconnect_invalid is true and the certificate is no longer valid.
|
||||
// Blocklisted certificates will skip the pki.disconnect_invalid check and return true.
|
||||
func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
|
||||
// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
|
||||
// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
|
||||
// check and return true.
|
||||
func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
|
||||
remoteCert := hostinfo.GetCert()
|
||||
if remoteCert == nil {
|
||||
return false //don't tear down tunnels for handshakes in progress
|
||||
}
|
||||
|
||||
caPool := cm.intf.pki.GetCAPool()
|
||||
err := caPool.VerifyCachedCertificate(now, remoteCert)
|
||||
if err == nil {
|
||||
return false //cert is still valid! yay!
|
||||
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
|
||||
// Block listed certificates should always be disconnected
|
||||
hostinfo.logger(cm.l).WithError(err).
|
||||
WithField("fingerprint", remoteCert.Fingerprint).
|
||||
Info("Remote certificate is blocked, tearing down the tunnel")
|
||||
return true
|
||||
} else if cm.intf.disconnectInvalid.Load() {
|
||||
hostinfo.logger(cm.l).WithError(err).
|
||||
WithField("fingerprint", remoteCert.Fingerprint).
|
||||
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
||||
return true
|
||||
} else {
|
||||
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
|
||||
return false
|
||||
}
|
||||
|
||||
caPool := n.intf.pki.GetCAPool()
|
||||
err := caPool.VerifyCachedCertificate(now, remoteCert)
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
||||
if !cm.punchy.GetPunch() {
|
||||
if !n.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
|
||||
// Block listed certificates should always be disconnected
|
||||
return false
|
||||
}
|
||||
|
||||
hostinfo.logger(n.l).WithError(err).
|
||||
WithField("fingerprint", remoteCert.Fingerprint).
|
||||
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
|
||||
if !n.punchy.GetPunch() {
|
||||
// Punching is disabled
|
||||
return
|
||||
}
|
||||
|
||||
if cm.intf.lightHouse.IsAnyLighthouseAddr(hostinfo.vpnAddrs) {
|
||||
// Do not punch to lighthouses, we assume our lighthouse update interval is good enough.
|
||||
// In the event the update interval is not sufficient to maintain NAT state then a publicly available lighthouse
|
||||
// would lose the ability to notify us and punchy.respond would become unreliable.
|
||||
return
|
||||
}
|
||||
|
||||
if cm.punchy.GetTargetEverything() {
|
||||
hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
|
||||
cm.metricsTxPunchy.Inc(1)
|
||||
cm.intf.outside.WriteTo([]byte{1}, addr)
|
||||
if n.punchy.GetTargetEverything() {
|
||||
hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
|
||||
n.metricsTxPunchy.Inc(1)
|
||||
n.intf.outside.WriteTo([]byte{1}, addr)
|
||||
})
|
||||
|
||||
} else if hostinfo.remote.IsValid() {
|
||||
cm.metricsTxPunchy.Inc(1)
|
||||
cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
|
||||
n.metricsTxPunchy.Inc(1)
|
||||
n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
||||
cs := cm.intf.pki.getCertState()
|
||||
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
||||
cs := n.intf.pki.getCertState()
|
||||
curCrt := hostinfo.ConnectionState.myCert
|
||||
curCrtVersion := curCrt.Version()
|
||||
myCrt := cs.getCertificate(curCrtVersion)
|
||||
if myCrt == nil {
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("version", curCrtVersion).
|
||||
WithField("reason", "local certificate removed").
|
||||
Info("Re-handshaking with remote")
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||
myCrt := cs.getCertificate(curCrt.Version())
|
||||
if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
|
||||
// The current tunnel is using the latest certificate and version, no need to rehandshake.
|
||||
return
|
||||
}
|
||||
peerCrt := hostinfo.ConnectionState.peerCert
|
||||
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
|
||||
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
|
||||
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("version", curCrtVersion).
|
||||
WithField("peerVersion", peerCrt.Certificate.Version()).
|
||||
WithField("reason", "local certificate version lower than peer, attempting to correct").
|
||||
Info("Re-handshaking with remote")
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
|
||||
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
|
||||
n.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("reason", "local certificate is not current").
|
||||
Info("Re-handshaking with remote")
|
||||
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||
return
|
||||
}
|
||||
if curCrtVersion < cs.initiatingVersion {
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("reason", "current cert version < pki.initiatingVersion").
|
||||
Info("Re-handshaking with remote")
|
||||
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||
return
|
||||
}
|
||||
n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"net/netip"
|
||||
@@ -22,7 +23,7 @@ func newTestLighthouse() *LightHouse {
|
||||
addrMap: map[netip.Addr]*RemoteList{},
|
||||
queryChan: make(chan netip.Addr, 10),
|
||||
}
|
||||
lighthouses := []netip.Addr{}
|
||||
lighthouses := map[netip.Addr]struct{}{}
|
||||
staticList := map[netip.Addr]struct{}{}
|
||||
|
||||
lh.lighthouses.Store(&lighthouses)
|
||||
@@ -63,10 +64,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
||||
ifce.pki.cs.Store(cs)
|
||||
|
||||
// Create manager
|
||||
conf := config.NewC(l)
|
||||
punchy := NewPunchyFromConfig(l, conf)
|
||||
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
||||
nc.intf = ifce
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
punchy := NewPunchyFromConfig(l, config.NewC(l))
|
||||
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
|
||||
p := []byte("")
|
||||
nb := make([]byte, 12, 12)
|
||||
out := make([]byte, mtu)
|
||||
@@ -84,33 +85,32 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||
|
||||
// We saw traffic out to vpnIp
|
||||
nc.Out(hostinfo)
|
||||
nc.In(hostinfo)
|
||||
assert.False(t, hostinfo.pendingDeletion.Load())
|
||||
nc.Out(hostinfo.localIndexId)
|
||||
nc.In(hostinfo.localIndexId)
|
||||
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||
assert.True(t, hostinfo.out.Load())
|
||||
assert.True(t, hostinfo.in.Load())
|
||||
assert.Contains(t, nc.out, hostinfo.localIndexId)
|
||||
|
||||
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
|
||||
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
||||
assert.False(t, hostinfo.pendingDeletion.Load())
|
||||
assert.False(t, hostinfo.out.Load())
|
||||
assert.False(t, hostinfo.in.Load())
|
||||
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
||||
assert.NotContains(t, nc.out, hostinfo.localIndexId)
|
||||
assert.NotContains(t, nc.in, hostinfo.localIndexId)
|
||||
|
||||
// Do another traffic check tick, this host should be pending deletion now
|
||||
nc.Out(hostinfo)
|
||||
assert.True(t, hostinfo.out.Load())
|
||||
nc.Out(hostinfo.localIndexId)
|
||||
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
||||
assert.True(t, hostinfo.pendingDeletion.Load())
|
||||
assert.False(t, hostinfo.out.Load())
|
||||
assert.False(t, hostinfo.in.Load())
|
||||
assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
||||
assert.NotContains(t, nc.out, hostinfo.localIndexId)
|
||||
assert.NotContains(t, nc.in, hostinfo.localIndexId)
|
||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||
|
||||
// Do a final traffic check tick, the host should now be removed
|
||||
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
||||
assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs)
|
||||
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
||||
assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||
assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||
}
|
||||
|
||||
@@ -146,10 +146,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||
ifce.pki.cs.Store(cs)
|
||||
|
||||
// Create manager
|
||||
conf := config.NewC(l)
|
||||
punchy := NewPunchyFromConfig(l, conf)
|
||||
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
||||
nc.intf = ifce
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
punchy := NewPunchyFromConfig(l, config.NewC(l))
|
||||
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
|
||||
p := []byte("")
|
||||
nb := make([]byte, 12, 12)
|
||||
out := make([]byte, mtu)
|
||||
@@ -167,129 +167,33 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||
|
||||
// We saw traffic out to vpnIp
|
||||
nc.Out(hostinfo)
|
||||
nc.In(hostinfo)
|
||||
assert.True(t, hostinfo.in.Load())
|
||||
assert.True(t, hostinfo.out.Load())
|
||||
assert.False(t, hostinfo.pendingDeletion.Load())
|
||||
nc.Out(hostinfo.localIndexId)
|
||||
nc.In(hostinfo.localIndexId)
|
||||
assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0])
|
||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||
|
||||
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
|
||||
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
||||
assert.False(t, hostinfo.pendingDeletion.Load())
|
||||
assert.False(t, hostinfo.out.Load())
|
||||
assert.False(t, hostinfo.in.Load())
|
||||
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
||||
assert.NotContains(t, nc.out, hostinfo.localIndexId)
|
||||
assert.NotContains(t, nc.in, hostinfo.localIndexId)
|
||||
|
||||
// Do another traffic check tick, this host should be pending deletion now
|
||||
nc.Out(hostinfo)
|
||||
nc.Out(hostinfo.localIndexId)
|
||||
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
||||
assert.True(t, hostinfo.pendingDeletion.Load())
|
||||
assert.False(t, hostinfo.out.Load())
|
||||
assert.False(t, hostinfo.in.Load())
|
||||
assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
||||
assert.NotContains(t, nc.out, hostinfo.localIndexId)
|
||||
assert.NotContains(t, nc.in, hostinfo.localIndexId)
|
||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||
|
||||
// We saw traffic, should no longer be pending deletion
|
||||
nc.In(hostinfo)
|
||||
nc.In(hostinfo.localIndexId)
|
||||
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
||||
assert.False(t, hostinfo.pendingDeletion.Load())
|
||||
assert.False(t, hostinfo.out.Load())
|
||||
assert.False(t, hostinfo.in.Load())
|
||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||
}
|
||||
|
||||
func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
localrange := netip.MustParsePrefix("10.1.1.1/24")
|
||||
vpnAddrs := []netip.Addr{netip.MustParseAddr("172.1.1.2")}
|
||||
preferredRanges := []netip.Prefix{localrange}
|
||||
|
||||
// Very incomplete mock objects
|
||||
hostMap := newHostMap(l)
|
||||
hostMap.preferredRanges.Store(&preferredRanges)
|
||||
|
||||
cs := &CertState{
|
||||
initiatingVersion: cert.Version1,
|
||||
privateKey: []byte{},
|
||||
v1Cert: &dummyCert{version: cert.Version1},
|
||||
v1HandshakeBytes: []byte{},
|
||||
}
|
||||
|
||||
lh := newTestLighthouse()
|
||||
ifce := &Interface{
|
||||
hostMap: hostMap,
|
||||
inside: &test.NoopTun{},
|
||||
outside: &udp.NoopConn{},
|
||||
firewall: &Firewall{},
|
||||
lightHouse: lh,
|
||||
pki: &PKI{},
|
||||
handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
|
||||
l: l,
|
||||
}
|
||||
ifce.pki.cs.Store(cs)
|
||||
|
||||
// Create manager
|
||||
conf := config.NewC(l)
|
||||
conf.Settings["tunnels"] = map[string]any{
|
||||
"drop_inactive": true,
|
||||
}
|
||||
punchy := NewPunchyFromConfig(l, conf)
|
||||
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
||||
assert.True(t, nc.dropInactive.Load())
|
||||
nc.intf = ifce
|
||||
|
||||
// Add an ip we have established a connection w/ to hostmap
|
||||
hostinfo := &HostInfo{
|
||||
vpnAddrs: vpnAddrs,
|
||||
localIndexId: 1099,
|
||||
remoteIndexId: 9901,
|
||||
}
|
||||
hostinfo.ConnectionState = &ConnectionState{
|
||||
myCert: &dummyCert{version: cert.Version1},
|
||||
H: &noise.HandshakeState{},
|
||||
}
|
||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||
|
||||
// Do a traffic check tick, in and out should be cleared but should not be pending deletion
|
||||
nc.Out(hostinfo)
|
||||
nc.In(hostinfo)
|
||||
assert.True(t, hostinfo.out.Load())
|
||||
assert.True(t, hostinfo.in.Load())
|
||||
|
||||
now := time.Now()
|
||||
decision, _, _ := nc.makeTrafficDecision(hostinfo.localIndexId, now)
|
||||
assert.Equal(t, tryRehandshake, decision)
|
||||
assert.Equal(t, now, hostinfo.lastUsed)
|
||||
assert.False(t, hostinfo.pendingDeletion.Load())
|
||||
assert.False(t, hostinfo.out.Load())
|
||||
assert.False(t, hostinfo.in.Load())
|
||||
|
||||
decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*5))
|
||||
assert.Equal(t, doNothing, decision)
|
||||
assert.Equal(t, now, hostinfo.lastUsed)
|
||||
assert.False(t, hostinfo.pendingDeletion.Load())
|
||||
assert.False(t, hostinfo.out.Load())
|
||||
assert.False(t, hostinfo.in.Load())
|
||||
|
||||
// Do another traffic check tick, should still not be pending deletion
|
||||
decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*10))
|
||||
assert.Equal(t, doNothing, decision)
|
||||
assert.Equal(t, now, hostinfo.lastUsed)
|
||||
assert.False(t, hostinfo.pendingDeletion.Load())
|
||||
assert.False(t, hostinfo.out.Load())
|
||||
assert.False(t, hostinfo.in.Load())
|
||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||
|
||||
// Finally advance beyond the inactivity timeout
|
||||
decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Minute*10))
|
||||
assert.Equal(t, closeTunnel, decision)
|
||||
assert.Equal(t, now, hostinfo.lastUsed)
|
||||
assert.False(t, hostinfo.pendingDeletion.Load())
|
||||
assert.False(t, hostinfo.out.Load())
|
||||
assert.False(t, hostinfo.in.Load())
|
||||
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
||||
assert.NotContains(t, nc.out, hostinfo.localIndexId)
|
||||
assert.NotContains(t, nc.in, hostinfo.localIndexId)
|
||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||
}
|
||||
@@ -360,10 +264,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
||||
ifce.disconnectInvalid.Store(true)
|
||||
|
||||
// Create manager
|
||||
conf := config.NewC(l)
|
||||
punchy := NewPunchyFromConfig(l, conf)
|
||||
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
||||
nc.intf = ifce
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
punchy := NewPunchyFromConfig(l, config.NewC(l))
|
||||
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
|
||||
ifce.connectionManager = nc
|
||||
|
||||
hostinfo := &HostInfo{
|
||||
@@ -446,10 +350,6 @@ func (d *dummyCert) PublicKey() []byte {
|
||||
return d.publicKey
|
||||
}
|
||||
|
||||
func (d *dummyCert) MarshalPublicKeyPEM() []byte {
|
||||
return cert.MarshalPublicKeyToPEM(d.curve, d.publicKey)
|
||||
}
|
||||
|
||||
func (d *dummyCert) Signature() []byte {
|
||||
return d.signature
|
||||
}
|
||||
|
||||
@@ -34,7 +34,6 @@ type Control struct {
|
||||
statsStart func()
|
||||
dnsStart func()
|
||||
lighthouseStart func()
|
||||
connectionManagerStart func(context.Context)
|
||||
}
|
||||
|
||||
type ControlHostInfo struct {
|
||||
@@ -64,9 +63,6 @@ func (c *Control) Start() {
|
||||
if c.dnsStart != nil {
|
||||
go c.dnsStart()
|
||||
}
|
||||
if c.connectionManagerStart != nil {
|
||||
go c.connectionManagerStart(c.ctx)
|
||||
}
|
||||
if c.lighthouseStart != nil {
|
||||
c.lighthouseStart()
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
||||
localIndexId: 201,
|
||||
vpnAddrs: []netip.Addr{vpnIp},
|
||||
relayState: RelayState{
|
||||
relays: nil,
|
||||
relays: map[netip.Addr]struct{}{},
|
||||
relayForByAddr: map[netip.Addr]*Relay{},
|
||||
relayForByIdx: map[uint32]*Relay{},
|
||||
},
|
||||
@@ -72,7 +72,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
||||
localIndexId: 201,
|
||||
vpnAddrs: []netip.Addr{vpnIp2},
|
||||
relayState: RelayState{
|
||||
relays: nil,
|
||||
relays: map[netip.Addr]struct{}{},
|
||||
relayForByAddr: map[netip.Addr]*Relay{},
|
||||
relayForByIdx: map[uint32]*Relay{},
|
||||
},
|
||||
|
||||
@@ -174,10 +174,6 @@ func (c *Control) GetHostmap() *HostMap {
|
||||
return c.f.hostMap
|
||||
}
|
||||
|
||||
func (c *Control) GetF() *Interface {
|
||||
return c.f
|
||||
}
|
||||
|
||||
func (c *Control) GetCertState() *CertState {
|
||||
return c.f.pki.getCertState()
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -39,7 +40,7 @@ func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecord
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
|
||||
func (d *dnsRecords) query(q uint16, data string) netip.Addr {
|
||||
data = strings.ToLower(data)
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
@@ -57,7 +58,7 @@ func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
func (d *dnsRecords) QueryCert(data string) string {
|
||||
func (d *dnsRecords) queryCert(data string) string {
|
||||
ip, err := netip.ParseAddr(data[:len(data)-1])
|
||||
if err != nil {
|
||||
return ""
|
||||
@@ -122,7 +123,7 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
qType := dns.TypeToString[q.Qtype]
|
||||
d.l.Debugf("Query for %s %s", qType, q.Name)
|
||||
ip := d.Query(q.Qtype, q.Name)
|
||||
ip := d.query(q.Qtype, q.Name)
|
||||
if ip.IsValid() {
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
|
||||
if err == nil {
|
||||
@@ -135,7 +136,7 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
||||
return
|
||||
}
|
||||
d.l.Debugf("Query for TXT %s", q.Name)
|
||||
ip := d.QueryCert(q.Name)
|
||||
ip := d.queryCert(q.Name)
|
||||
if ip != "" {
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
|
||||
if err == nil {
|
||||
@@ -163,18 +164,18 @@ func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
w.WriteMsg(m)
|
||||
}
|
||||
|
||||
func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
|
||||
func dnsMain(ctx context.Context, l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
|
||||
dnsR = newDnsRecords(l, cs, hostMap)
|
||||
|
||||
// attach request handler func
|
||||
dns.HandleFunc(".", dnsR.handleDnsRequest)
|
||||
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
reloadDns(l, c)
|
||||
reloadDns(ctx, l, c)
|
||||
})
|
||||
|
||||
return func() {
|
||||
startDns(l, c)
|
||||
startDns(ctx, l, c)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -187,24 +188,24 @@ func getDnsServerAddr(c *config.C) string {
|
||||
return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
|
||||
}
|
||||
|
||||
func startDns(l *logrus.Logger, c *config.C) {
|
||||
func startDns(ctx context.Context, l *logrus.Logger, c *config.C) {
|
||||
dnsAddr = getDnsServerAddr(c)
|
||||
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
|
||||
l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder")
|
||||
err := dnsServer.ListenAndServe()
|
||||
defer dnsServer.Shutdown()
|
||||
defer dnsServer.ShutdownContext(ctx)
|
||||
if err != nil {
|
||||
l.Errorf("Failed to start server: %s\n ", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func reloadDns(l *logrus.Logger, c *config.C) {
|
||||
func reloadDns(ctx context.Context, l *logrus.Logger, c *config.C) {
|
||||
if dnsAddr == getDnsServerAddr(c) {
|
||||
l.Debug("No DNS server config change detected")
|
||||
return
|
||||
}
|
||||
|
||||
l.Debug("Restarting DNS server")
|
||||
dnsServer.Shutdown()
|
||||
go startDns(l, c)
|
||||
dnsServer.ShutdownContext(ctx)
|
||||
go startDns(ctx, l, c)
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.yaml.in/yaml/v3"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func BenchmarkHotPath(b *testing.B) {
|
||||
@@ -97,41 +97,6 @@ func TestGoodHandshake(t *testing.T) {
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestGoodHandshakeNoOverlap(t *testing.T) {
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "2001::69/24", nil) //look ma, cross-stack!
|
||||
|
||||
// Put their info in our lighthouse
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
empty := []byte{}
|
||||
t.Log("do something to cause a handshake")
|
||||
myControl.GetF().SendMessageToVpnAddr(header.Test, header.MessageNone, theirVpnIpNet[0].Addr(), empty, empty, empty)
|
||||
|
||||
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
|
||||
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
|
||||
|
||||
t.Log("Get their stage 1 packet")
|
||||
stage1Packet := theirControl.GetFromUDP(true)
|
||||
|
||||
t.Log("Have me consume their stage 1 packet. I have a tunnel now")
|
||||
myControl.InjectUDPPacket(stage1Packet)
|
||||
|
||||
t.Log("Wait until we see a test packet come through to make sure we give the tunnel time to complete")
|
||||
myControl.WaitForType(header.Test, 0, theirControl)
|
||||
|
||||
t.Log("Make sure our host infos are correct")
|
||||
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
|
||||
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestWrongResponderHandshake(t *testing.T) {
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
|
||||
@@ -499,35 +464,6 @@ func TestRelays(t *testing.T) {
|
||||
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
||||
}
|
||||
|
||||
func TestRelaysDontCareAboutIps(t *testing.T) {
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "2001::9999/24", m{"relay": m{"am_relay": true}})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
||||
|
||||
// Teach my how to get to the relay and that their can be reached via the relay
|
||||
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
||||
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
|
||||
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||
|
||||
// Build a router so we don't have to reason who gets which packet
|
||||
r := router.NewR(t, myControl, relayControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
relayControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
t.Log("Trigger a handshake from me to them via the relay")
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||
|
||||
p := r.RouteForAllUntilTxTun(theirControl)
|
||||
r.Log("Assert the tunnel works")
|
||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
||||
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
||||
}
|
||||
|
||||
func TestReestablishRelays(t *testing.T) {
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||
@@ -570,7 +506,7 @@ func TestReestablishRelays(t *testing.T) {
|
||||
curIndexes := len(myControl.GetHostmap().Indexes)
|
||||
for curIndexes >= start {
|
||||
curIndexes = len(myControl.GetHostmap().Indexes)
|
||||
r.Logf("Wait for the dead index to go away:start=%v indexes, current=%v indexes", start, curIndexes)
|
||||
r.Logf("Wait for the dead index to go away:start=%v indexes, currnet=%v indexes", start, curIndexes)
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail"))
|
||||
|
||||
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
|
||||
@@ -1116,9 +1052,6 @@ func TestRehandshakingLoser(t *testing.T) {
|
||||
t.Log("Stand up a tunnel between me and them")
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
|
||||
myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||
theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||
|
||||
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
|
||||
|
||||
r.Log("Renew their certificate and spin until mine sees it")
|
||||
@@ -1291,109 +1224,3 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) {
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) {
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "2001::1/64", m{"lighthouse": m{"am_lighthouse": true}})
|
||||
|
||||
o := m{
|
||||
"static_host_map": m{
|
||||
lhVpnIpNet[0].Addr().String(): []string{lhUdpAddr.String()},
|
||||
},
|
||||
"lighthouse": m{
|
||||
"hosts": []string{lhVpnIpNet[0].Addr().String()},
|
||||
"local_allow_list": m{
|
||||
// Try and block our lighthouse updates from using the actual addresses assigned to this computer
|
||||
// If we start discovering addresses the test router doesn't know about then test traffic cant flow
|
||||
"10.0.0.0/24": true,
|
||||
"::/0": false,
|
||||
},
|
||||
},
|
||||
}
|
||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.2/24, ff::2/64", o)
|
||||
theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o)
|
||||
|
||||
// Build a router so we don't have to reason who gets which packet
|
||||
r := router.NewR(t, lhControl, myControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
|
||||
// Start the servers
|
||||
lhControl.Start()
|
||||
myControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
t.Log("Stand up an ipv6 tunnel between me and them")
|
||||
assert.True(t, myVpnIpNet[1].Addr().Is6())
|
||||
assert.True(t, theirVpnIpNet[1].Addr().Is6())
|
||||
assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r)
|
||||
|
||||
lhControl.Stop()
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestGoodHandshakeUnsafeDest(t *testing.T) {
|
||||
unsafePrefix := "192.168.6.0/24"
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(cert.Version2, ca, caKey, "spooky", "10.128.0.2/24", netip.MustParseAddrPort("10.64.0.2:4242"), unsafePrefix, nil)
|
||||
route := m{"route": unsafePrefix, "via": theirVpnIpNet[0].Addr().String()}
|
||||
myCfg := m{
|
||||
"tun": m{
|
||||
"unsafe_routes": []m{route},
|
||||
},
|
||||
}
|
||||
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", myCfg)
|
||||
t.Logf("my config %v", myConfig)
|
||||
// Put their info in our lighthouse
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||
|
||||
spookyDest := netip.MustParseAddr("192.168.6.4")
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
|
||||
myControl.InjectTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||
|
||||
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
|
||||
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
|
||||
|
||||
t.Log("Get their stage 1 packet so that we can play with it")
|
||||
stage1Packet := theirControl.GetFromUDP(true)
|
||||
|
||||
t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
|
||||
// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
|
||||
badPacket := stage1Packet.Copy()
|
||||
badPacket.Data = badPacket.Data[:len(badPacket.Data)-header.Len]
|
||||
myControl.InjectUDPPacket(badPacket)
|
||||
|
||||
t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
|
||||
myControl.InjectUDPPacket(stage1Packet)
|
||||
|
||||
t.Log("Wait until we see my cached packet come through")
|
||||
myControl.WaitForType(1, 0, theirControl)
|
||||
|
||||
t.Log("Make sure our host infos are correct")
|
||||
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
|
||||
|
||||
t.Log("Get that cached packet and make sure it looks right")
|
||||
myCachedPacket := theirControl.GetFromTun(true)
|
||||
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80)
|
||||
|
||||
//reply
|
||||
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman"))
|
||||
//wait for reply
|
||||
theirControl.WaitForType(1, 0, myControl)
|
||||
theirCachedPacket := myControl.GetFromTun(true)
|
||||
assertUdpPacket(t, []byte("Hi from the spookyman"), theirCachedPacket, spookyDest, myVpnIpNet[0].Addr(), 80, 80)
|
||||
|
||||
t.Log("Do a bidirectional tunnel test")
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
|
||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
@@ -22,14 +22,15 @@ import (
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/e2e/router"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.yaml.in/yaml/v3"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type m = map[string]any
|
||||
|
||||
// newSimpleServer creates a nebula instance with many assumptions
|
||||
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||
l := NewTestLogger()
|
||||
|
||||
var vpnNetworks []netip.Prefix
|
||||
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
||||
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||
@@ -55,54 +56,7 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
|
||||
budpIp[3] = 239
|
||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
||||
}
|
||||
return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides)
|
||||
}
|
||||
|
||||
func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||
return newSimpleServerWithUdpAndUnsafeNetworks(v, caCrt, caKey, name, sVpnNetworks, udpAddr, "", overrides)
|
||||
}
|
||||
|
||||
func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, sUnsafeNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||
l := NewTestLogger()
|
||||
|
||||
var vpnNetworks []netip.Prefix
|
||||
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
||||
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
vpnNetworks = append(vpnNetworks, vpnIpNet)
|
||||
}
|
||||
|
||||
if len(vpnNetworks) == 0 {
|
||||
panic("no vpn networks")
|
||||
}
|
||||
|
||||
firewallInbound := []m{{
|
||||
"proto": "any",
|
||||
"port": "any",
|
||||
"host": "any",
|
||||
}}
|
||||
|
||||
var unsafeNetworks []netip.Prefix
|
||||
if sUnsafeNetworks != "" {
|
||||
firewallInbound = []m{{
|
||||
"proto": "any",
|
||||
"port": "any",
|
||||
"host": "any",
|
||||
"local_cidr": "0.0.0.0/0",
|
||||
}}
|
||||
|
||||
for _, sn := range strings.Split(sUnsafeNetworks, ",") {
|
||||
x, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
unsafeNetworks = append(unsafeNetworks, x)
|
||||
}
|
||||
}
|
||||
|
||||
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, unsafeNetworks, []string{})
|
||||
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
|
||||
|
||||
caB, err := caCrt.MarshalPEM()
|
||||
if err != nil {
|
||||
@@ -122,7 +76,11 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
|
||||
"port": "any",
|
||||
"host": "any",
|
||||
}},
|
||||
"inbound": firewallInbound,
|
||||
"inbound": []m{{
|
||||
"proto": "any",
|
||||
"port": "any",
|
||||
"host": "any",
|
||||
}},
|
||||
},
|
||||
//"handshakes": m{
|
||||
// "try_interval": "1s",
|
||||
@@ -171,109 +129,6 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
|
||||
return control, vpnNetworks, udpAddr, c
|
||||
}
|
||||
|
||||
// newServer creates a nebula instance with fewer assumptions
|
||||
func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||
l := NewTestLogger()
|
||||
|
||||
vpnNetworks := certs[len(certs)-1].Networks()
|
||||
|
||||
var udpAddr netip.AddrPort
|
||||
if vpnNetworks[0].Addr().Is4() {
|
||||
budpIp := vpnNetworks[0].Addr().As4()
|
||||
budpIp[1] -= 128
|
||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
|
||||
} else {
|
||||
budpIp := vpnNetworks[0].Addr().As16()
|
||||
// beef for funsies
|
||||
budpIp[2] = 190
|
||||
budpIp[3] = 239
|
||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
||||
}
|
||||
|
||||
caStr := ""
|
||||
for _, ca := range caCrt {
|
||||
x, err := ca.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
caStr += string(x)
|
||||
}
|
||||
certStr := ""
|
||||
for _, c := range certs {
|
||||
x, err := c.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
certStr += string(x)
|
||||
}
|
||||
|
||||
mc := m{
|
||||
"pki": m{
|
||||
"ca": caStr,
|
||||
"cert": certStr,
|
||||
"key": string(key),
|
||||
},
|
||||
//"tun": m{"disabled": true},
|
||||
"firewall": m{
|
||||
"outbound": []m{{
|
||||
"proto": "any",
|
||||
"port": "any",
|
||||
"host": "any",
|
||||
}},
|
||||
"inbound": []m{{
|
||||
"proto": "any",
|
||||
"port": "any",
|
||||
"host": "any",
|
||||
}},
|
||||
},
|
||||
//"handshakes": m{
|
||||
// "try_interval": "1s",
|
||||
//},
|
||||
"listen": m{
|
||||
"host": udpAddr.Addr().String(),
|
||||
"port": udpAddr.Port(),
|
||||
},
|
||||
"logging": m{
|
||||
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
|
||||
"level": l.Level.String(),
|
||||
},
|
||||
"timers": m{
|
||||
"pending_deletion_interval": 2,
|
||||
"connection_alive_interval": 2,
|
||||
},
|
||||
}
|
||||
|
||||
if overrides != nil {
|
||||
final := m{}
|
||||
err := mergo.Merge(&final, overrides, mergo.WithAppendSlice)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
mc = final
|
||||
}
|
||||
|
||||
cb, err := yaml.Marshal(mc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
c := config.NewC(l)
|
||||
cStr := string(cb)
|
||||
c.LoadString(cStr)
|
||||
|
||||
control, err := nebula.Main(c, false, "e2e-test", l, nil)
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return control, vpnNetworks, udpAddr, c
|
||||
}
|
||||
|
||||
type doneCb func()
|
||||
|
||||
func deadline(t *testing.T, seconds time.Duration) doneCb {
|
||||
@@ -308,10 +163,10 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpn
|
||||
// Get both host infos
|
||||
//TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
|
||||
hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
|
||||
require.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
|
||||
assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
|
||||
|
||||
hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
|
||||
require.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
|
||||
assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
|
||||
|
||||
// Check that both vpn and real addr are correct
|
||||
assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")
|
||||
|
||||
@@ -700,7 +700,6 @@ func (r *R) FlushAll() {
|
||||
r.Unlock()
|
||||
panic("Can't FlushAll for host: " + p.To.String())
|
||||
}
|
||||
receiver.InjectUDPPacket(p)
|
||||
r.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,367 +0,0 @@
|
||||
//go:build e2e_testing
|
||||
// +build e2e_testing
|
||||
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/cert_test"
|
||||
"github.com/slackhq/nebula/e2e/router"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestDropInactiveTunnels(t *testing.T) {
|
||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||
// under ideal conditions
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "5s"}})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "10m"}})
|
||||
|
||||
// Share our underlay information
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
|
||||
r.Log("Assert the tunnel between me and them works")
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
|
||||
r.Log("Go inactive and wait for the tunnels to get dropped")
|
||||
waitStart := time.Now()
|
||||
for {
|
||||
myIndexes := len(myControl.GetHostmap().Indexes)
|
||||
theirIndexes := len(theirControl.GetHostmap().Indexes)
|
||||
if myIndexes == 0 && theirIndexes == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
since := time.Since(waitStart)
|
||||
r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since)
|
||||
if since > time.Second*30 {
|
||||
t.Fatal("Tunnel should have been declared inactive after 5 seconds and before 30 seconds")
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
r.FlushAll()
|
||||
}
|
||||
|
||||
r.Logf("Inactive tunnels were dropped within %v", time.Since(waitStart))
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestCertUpgrade(t *testing.T) {
|
||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||
// under ideal conditions
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
caB, err := ca.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
|
||||
ca2B, err := ca2.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
|
||||
|
||||
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
||||
_, myCert2Pem := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
||||
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert}, myPrivKey, m{})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
||||
|
||||
// Share our underlay information
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
|
||||
r.Log("Assert the tunnel between me and them works")
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
r.Log("yay")
|
||||
//todo ???
|
||||
time.Sleep(1 * time.Second)
|
||||
r.FlushAll()
|
||||
|
||||
mc := m{
|
||||
"pki": m{
|
||||
"ca": caStr,
|
||||
"cert": string(myCert2Pem),
|
||||
"key": string(myPrivKey),
|
||||
},
|
||||
//"tun": m{"disabled": true},
|
||||
"firewall": myC.Settings["firewall"],
|
||||
//"handshakes": m{
|
||||
// "try_interval": "1s",
|
||||
//},
|
||||
"listen": myC.Settings["listen"],
|
||||
"logging": myC.Settings["logging"],
|
||||
"timers": myC.Settings["timers"],
|
||||
}
|
||||
|
||||
cb, err := yaml.Marshal(mc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
r.Logf("reload new v2-only config")
|
||||
err = myC.ReloadConfigString(string(cb))
|
||||
assert.NoError(t, err)
|
||||
r.Log("yay, spin until their sees it")
|
||||
waitStart := time.Now()
|
||||
for {
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||
if c == nil {
|
||||
r.Log("nil")
|
||||
} else {
|
||||
version := c.Cert.Version()
|
||||
r.Logf("version %d", version)
|
||||
if version == cert.Version2 {
|
||||
break
|
||||
}
|
||||
}
|
||||
since := time.Since(waitStart)
|
||||
if since > time.Second*10 {
|
||||
t.Fatal("Cert should be new by now")
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestCertDowngrade(t *testing.T) {
|
||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||
// under ideal conditions
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
caB, err := ca.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
|
||||
ca2B, err := ca2.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
|
||||
|
||||
myCert, _, myPrivKey, myCertPem := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
||||
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
||||
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
||||
|
||||
// Share our underlay information
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
|
||||
r.Log("Assert the tunnel between me and them works")
|
||||
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||
//r.Log("yay")
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
r.Log("yay")
|
||||
//todo ???
|
||||
time.Sleep(1 * time.Second)
|
||||
r.FlushAll()
|
||||
|
||||
mc := m{
|
||||
"pki": m{
|
||||
"ca": caStr,
|
||||
"cert": string(myCertPem),
|
||||
"key": string(myPrivKey),
|
||||
},
|
||||
"firewall": myC.Settings["firewall"],
|
||||
"listen": myC.Settings["listen"],
|
||||
"logging": myC.Settings["logging"],
|
||||
"timers": myC.Settings["timers"],
|
||||
}
|
||||
|
||||
cb, err := yaml.Marshal(mc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
r.Logf("reload new v1-only config")
|
||||
err = myC.ReloadConfigString(string(cb))
|
||||
assert.NoError(t, err)
|
||||
r.Log("yay, spin until their sees it")
|
||||
waitStart := time.Now()
|
||||
for {
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||
if c == nil || c2 == nil {
|
||||
r.Log("nil")
|
||||
} else {
|
||||
version := c.Cert.Version()
|
||||
theirVersion := c2.Cert.Version()
|
||||
r.Logf("version %d,%d", version, theirVersion)
|
||||
if version == cert.Version1 {
|
||||
break
|
||||
}
|
||||
}
|
||||
since := time.Since(waitStart)
|
||||
if since > time.Second*5 {
|
||||
r.Log("it is unusual that the cert is not new yet, but not a failure yet")
|
||||
}
|
||||
if since > time.Second*10 {
|
||||
r.Log("wtf")
|
||||
t.Fatal("Cert should be new by now")
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestCertMismatchCorrection(t *testing.T) {
|
||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||
// under ideal conditions
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
|
||||
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
||||
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
||||
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
||||
|
||||
// Share our underlay information
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
|
||||
r.Log("Assert the tunnel between me and them works")
|
||||
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||
//r.Log("yay")
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
r.Log("yay")
|
||||
//todo ???
|
||||
time.Sleep(1 * time.Second)
|
||||
r.FlushAll()
|
||||
|
||||
waitStart := time.Now()
|
||||
for {
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||
if c == nil || c2 == nil {
|
||||
r.Log("nil")
|
||||
} else {
|
||||
version := c.Cert.Version()
|
||||
theirVersion := c2.Cert.Version()
|
||||
r.Logf("version %d,%d", version, theirVersion)
|
||||
if version == theirVersion {
|
||||
break
|
||||
}
|
||||
}
|
||||
since := time.Since(waitStart)
|
||||
if since > time.Second*5 {
|
||||
r.Log("wtf")
|
||||
}
|
||||
if since > time.Second*10 {
|
||||
r.Log("wtf")
|
||||
t.Fatal("Cert should be new by now")
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestCrossStackRelaysWork(t *testing.T) {
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
|
||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
|
||||
theirUdp := netip.MustParseAddrPort("10.0.0.2:4242")
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}})
|
||||
|
||||
//myVpnV4 := myVpnIpNet[0]
|
||||
myVpnV6 := myVpnIpNet[1]
|
||||
relayVpnV4 := relayVpnIpNet[0]
|
||||
relayVpnV6 := relayVpnIpNet[1]
|
||||
theirVpnV6 := theirVpnIpNet[0]
|
||||
|
||||
// Teach my how to get to the relay and that their can be reached via the relay
|
||||
myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr)
|
||||
myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr)
|
||||
myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()})
|
||||
relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr)
|
||||
|
||||
// Build a router so we don't have to reason who gets which packet
|
||||
r := router.NewR(t, myControl, relayControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
relayControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
t.Log("Trigger a handshake from me to them via the relay")
|
||||
myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
|
||||
|
||||
p := r.RouteForAllUntilTxTun(theirControl)
|
||||
r.Log("Assert the tunnel works")
|
||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
|
||||
|
||||
t.Log("reply?")
|
||||
theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
|
||||
p = r.RouteForAllUntilTxTun(myControl)
|
||||
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
|
||||
|
||||
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
||||
//t.Log("finish up")
|
||||
//myControl.Stop()
|
||||
//theirControl.Stop()
|
||||
//relayControl.Stop()
|
||||
}
|
||||
@@ -275,10 +275,6 @@ tun:
|
||||
# On linux only, set to true to manage unsafe routes directly on the system route table with gateway routes instead of
|
||||
# in nebula configuration files. Default false, not reloadable.
|
||||
#use_system_route_table: false
|
||||
# Buffer size for reading routes updates. 0 means default system buffer size. (/proc/sys/net/core/rmem_default).
|
||||
# If using massive routes updates, for example BGP, you may need to increase this value to avoid packet loss.
|
||||
# SO_RCVBUFFORCE is used to avoid having to raise the system wide max
|
||||
#use_system_route_table_buffer_size: 0
|
||||
|
||||
# Configure logging level
|
||||
logging:
|
||||
@@ -338,18 +334,6 @@ logging:
|
||||
# after receiving the response for lighthouse queries
|
||||
#trigger_buffer: 64
|
||||
|
||||
# Tunnel manager settings
|
||||
#tunnels:
|
||||
# drop_inactive controls whether inactive tunnels are maintained or dropped after the inactive_timeout period has
|
||||
# elapsed.
|
||||
# In general, it is a good idea to enable this setting. It will be enabled by default in a future release.
|
||||
# This setting is reloadable
|
||||
#drop_inactive: false
|
||||
|
||||
# inactivity_timeout controls how long a tunnel MUST NOT see any inbound or outbound traffic before being considered
|
||||
# inactive and eligible to be dropped.
|
||||
# This setting is reloadable
|
||||
#inactivity_timeout: 10m
|
||||
|
||||
# Nebula security group configuration
|
||||
firewall:
|
||||
|
||||
@@ -5,12 +5,8 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/overlay"
|
||||
"github.com/slackhq/nebula/service"
|
||||
)
|
||||
|
||||
@@ -63,16 +59,7 @@ pki:
|
||||
if err := cfg.LoadString(configStr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger := logrus.New()
|
||||
logger.Out = os.Stdout
|
||||
|
||||
ctrl, err := nebula.Main(&cfg, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
svc, err := service.New(ctrl)
|
||||
svc, err := service.New(&cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
29
firewall.go
29
firewall.go
@@ -417,10 +417,8 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
||||
return nil
|
||||
}
|
||||
|
||||
var ErrUnknownNetworkType = errors.New("unknown network type")
|
||||
var ErrPeerRejected = errors.New("remote address is not within a network that we handle")
|
||||
var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate networks")
|
||||
var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses")
|
||||
var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
|
||||
var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
|
||||
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
||||
|
||||
// Drop returns an error if the packet should be dropped, explaining why. It
|
||||
@@ -431,31 +429,18 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
||||
return nil
|
||||
}
|
||||
|
||||
// Make sure remote address matches nebula certificate, and determine how to treat it
|
||||
if h.networks == nil {
|
||||
// Simple case: Certificate has one address and no unsafe networks
|
||||
if h.vpnAddrs[0] != fp.RemoteAddr {
|
||||
// Make sure remote address matches nebula certificate
|
||||
if h.networks != nil {
|
||||
if !h.networks.Contains(fp.RemoteAddr) {
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
} else {
|
||||
nwType, ok := h.networks.Lookup(fp.RemoteAddr)
|
||||
if !ok {
|
||||
// Simple case: Certificate has one address and no unsafe networks
|
||||
if h.vpnAddrs[0] != fp.RemoteAddr {
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
switch nwType {
|
||||
case NetworkTypeVPN:
|
||||
break // nothing special
|
||||
case NetworkTypeVPNPeer:
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrPeerRejected // reject for now, one day this may have different FW rules
|
||||
case NetworkTypeUnsafe:
|
||||
break // nothing special, one day this may have different FW rules
|
||||
default:
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrUnknownNetworkType //should never happen
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure we are supposed to be handling this local ip address
|
||||
|
||||
440
firewall_test.go
440
firewall_test.go
@@ -8,8 +8,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
@@ -70,9 +68,6 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||
ti, err := netip.ParsePrefix("1.2.3.4/32")
|
||||
require.NoError(t, err)
|
||||
|
||||
ti6, err := netip.ParsePrefix("fd12::34/128")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
// An empty rule is any
|
||||
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
|
||||
@@ -97,22 +92,10 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
|
||||
assert.True(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
|
||||
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
|
||||
assert.True(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
|
||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
|
||||
assert.True(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
|
||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
|
||||
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
|
||||
assert.True(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
@@ -134,13 +117,6 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
anyIp6, err := netip.ParsePrefix("::/0")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||
|
||||
// Test error conditions
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
@@ -151,8 +127,7 @@ func TestFirewall_Drop(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
@@ -177,7 +152,7 @@ func TestFirewall_Drop(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
|
||||
}
|
||||
h.buildNetworks(myVpnNetworksTable, &c)
|
||||
h.buildNetworks(c.networks, c.unsafeNetworks)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
@@ -224,85 +199,6 @@ func TestFirewall_Drop(t *testing.T) {
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
func TestFirewall_DropV6(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("fd12::34"),
|
||||
RemoteAddr: netip.MustParseAddr("fd12::34"),
|
||||
LocalPort: 10,
|
||||
RemotePort: 90,
|
||||
Protocol: firewall.ProtoUDP,
|
||||
Fragment: false,
|
||||
}
|
||||
|
||||
c := dummyCert{
|
||||
name: "host1",
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("fd12::34/120")},
|
||||
groups: []string{"default-group"},
|
||||
issuer: "signer-shasum",
|
||||
}
|
||||
h := HostInfo{
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &cert.CachedCertificate{
|
||||
Certificate: &c,
|
||||
InvertedGroups: map[string]struct{}{"default-group": {}},
|
||||
},
|
||||
},
|
||||
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
|
||||
}
|
||||
h.buildNetworks(myVpnNetworksTable, &c)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// Drop outbound
|
||||
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
// Allow outbound because conntrack
|
||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
||||
|
||||
// test remote mismatch
|
||||
oldRemote := p.RemoteAddr
|
||||
p.RemoteAddr = netip.MustParseAddr("fd12::56")
|
||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
|
||||
p.RemoteAddr = oldRemote
|
||||
|
||||
// ensure signer doesn't get in the way of group checks
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caSha doesn't drop on match
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
|
||||
// ensure ca name doesn't get in the way of group checks
|
||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caName doesn't drop on match
|
||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
f := &Firewall{}
|
||||
ft := FirewallTable{
|
||||
@@ -312,10 +208,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
pfix := netip.MustParsePrefix("172.1.1.1/32")
|
||||
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
|
||||
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
|
||||
|
||||
pfix6 := netip.MustParsePrefix("fd11::11/128")
|
||||
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
|
||||
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
b.Run("fail on proto", func(b *testing.B) {
|
||||
@@ -347,15 +239,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
|
||||
}
|
||||
})
|
||||
b.Run("pass proto, port, fail on local CIDRv6", func(b *testing.B) {
|
||||
c := &cert.CachedCertificate{
|
||||
Certificate: &dummyCert{},
|
||||
}
|
||||
ip := netip.MustParsePrefix("fd99::99/128")
|
||||
for n := 0; n < b.N; n++ {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
|
||||
c := &cert.CachedCertificate{
|
||||
@@ -369,18 +252,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
|
||||
}
|
||||
})
|
||||
b.Run("pass proto, port, any local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
|
||||
c := &cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "nope",
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
|
||||
},
|
||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
|
||||
c := &cert.CachedCertificate{
|
||||
@@ -394,18 +265,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
|
||||
}
|
||||
})
|
||||
b.Run("pass proto, port, specific local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
|
||||
c := &cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "nope",
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
|
||||
},
|
||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("pass on group on any local cidr", func(b *testing.B) {
|
||||
c := &cert.CachedCertificate{
|
||||
@@ -430,17 +289,6 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
|
||||
}
|
||||
})
|
||||
b.Run("pass on group on specific local cidr6", func(b *testing.B) {
|
||||
c := &cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "nope",
|
||||
},
|
||||
InvertedGroups: map[string]struct{}{"good-group": {}},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("pass on name", func(b *testing.B) {
|
||||
c := &cert.CachedCertificate{
|
||||
@@ -459,8 +307,6 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
@@ -486,7 +332,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h.buildNetworks(myVpnNetworksTable, c.Certificate)
|
||||
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||
|
||||
c1 := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
@@ -501,7 +347,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||
peerCert: &c1,
|
||||
},
|
||||
}
|
||||
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
|
||||
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
@@ -518,8 +364,6 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
@@ -551,7 +395,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
|
||||
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||
|
||||
c2 := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
@@ -566,7 +410,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h2.buildNetworks(myVpnNetworksTable, c2.Certificate)
|
||||
h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
|
||||
|
||||
c3 := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
@@ -581,7 +425,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h3.buildNetworks(myVpnNetworksTable, c3.Certificate)
|
||||
h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
@@ -603,50 +447,10 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
|
||||
}
|
||||
|
||||
func TestFirewall_Drop3V6(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("fd12::34"),
|
||||
RemoteAddr: netip.MustParseAddr("fd12::34"),
|
||||
LocalPort: 1,
|
||||
RemotePort: 1,
|
||||
Protocol: firewall.ProtoUDP,
|
||||
Fragment: false,
|
||||
}
|
||||
|
||||
network := netip.MustParsePrefix("fd12::34/120")
|
||||
c := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "host-owner",
|
||||
networks: []netip.Prefix{network},
|
||||
},
|
||||
}
|
||||
h := HostInfo{
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c,
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h.buildNetworks(myVpnNetworksTable, c.Certificate)
|
||||
|
||||
// Test a remote address match
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
cp := cert.NewCAPool()
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
@@ -673,7 +477,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h.buildNetworks(myVpnNetworksTable, c.Certificate)
|
||||
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
@@ -706,52 +510,6 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
}
|
||||
|
||||
func TestFirewall_DropIPSpoofing(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
|
||||
|
||||
c := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "host-owner",
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.1/24")},
|
||||
},
|
||||
}
|
||||
|
||||
c1 := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "host",
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.2/24")},
|
||||
unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")},
|
||||
},
|
||||
}
|
||||
h1 := HostInfo{
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c1,
|
||||
},
|
||||
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
|
||||
}
|
||||
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// Packet spoofed by `c1`. Note that the remote addr is not a valid one.
|
||||
p := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("192.0.2.1"),
|
||||
RemoteAddr: netip.MustParseAddr("192.0.2.3"),
|
||||
LocalPort: 1,
|
||||
RemotePort: 1,
|
||||
Protocol: firewall.ProtoUDP,
|
||||
Fragment: false,
|
||||
}
|
||||
assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
|
||||
}
|
||||
|
||||
func BenchmarkLookup(b *testing.B) {
|
||||
ml := func(m map[string]struct{}, a [][]string) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
@@ -969,21 +727,6 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
|
||||
|
||||
// Test adding rule with cidr ipv6
|
||||
cidr6 := netip.MustParsePrefix("fd00::/8")
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
// Test adding rule with local_cidr ipv6
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
|
||||
|
||||
// Test adding rule with ca_sha
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
@@ -1063,171 +806,6 @@ func TestFirewall_convertRule(t *testing.T) {
|
||||
assert.Equal(t, "group1", r.Group)
|
||||
}
|
||||
|
||||
type testcase struct {
|
||||
h *HostInfo
|
||||
p firewall.Packet
|
||||
c cert.Certificate
|
||||
err error
|
||||
}
|
||||
|
||||
func (c *testcase) Test(t *testing.T, fw *Firewall) {
|
||||
t.Helper()
|
||||
cp := cert.NewCAPool()
|
||||
resetConntrack(fw)
|
||||
err := fw.Drop(c.p, true, c.h, cp, nil)
|
||||
if c.err == nil {
|
||||
require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
|
||||
} else {
|
||||
require.ErrorIs(t, c.err, err, "failed to drop remote address %s", c.p.RemoteAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
|
||||
c1 := dummyCert{
|
||||
name: "host1",
|
||||
networks: theirPrefixes,
|
||||
groups: []string{"default-group"},
|
||||
issuer: "signer-shasum",
|
||||
}
|
||||
h := HostInfo{
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &cert.CachedCertificate{
|
||||
Certificate: &c1,
|
||||
InvertedGroups: map[string]struct{}{"default-group": {}},
|
||||
},
|
||||
},
|
||||
vpnAddrs: make([]netip.Addr, len(theirPrefixes)),
|
||||
}
|
||||
for i := range theirPrefixes {
|
||||
h.vpnAddrs[i] = theirPrefixes[i].Addr()
|
||||
}
|
||||
h.buildNetworks(setup.myVpnNetworksTable, &c1)
|
||||
p := firewall.Packet{
|
||||
LocalAddr: setup.c.Networks()[0].Addr(), //todo?
|
||||
RemoteAddr: theirPrefixes[0].Addr(),
|
||||
LocalPort: 10,
|
||||
RemotePort: 90,
|
||||
Protocol: firewall.ProtoUDP,
|
||||
Fragment: false,
|
||||
}
|
||||
return testcase{
|
||||
h: &h,
|
||||
p: p,
|
||||
c: &c1,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
type testsetup struct {
|
||||
c dummyCert
|
||||
myVpnNetworksTable *bart.Lite
|
||||
fw *Firewall
|
||||
}
|
||||
|
||||
func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup {
|
||||
c := dummyCert{
|
||||
name: "me",
|
||||
networks: myPrefixes,
|
||||
groups: []string{"default-group"},
|
||||
issuer: "signer-shasum",
|
||||
}
|
||||
|
||||
return newSetupFromCert(t, l, c)
|
||||
}
|
||||
|
||||
func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
for _, prefix := range c.Networks() {
|
||||
myVpnNetworksTable.Insert(prefix)
|
||||
}
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
|
||||
return testsetup{
|
||||
c: c,
|
||||
fw: fw,
|
||||
myVpnNetworksTable: myVpnNetworksTable,
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
|
||||
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
|
||||
// for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out
|
||||
t.Run("allow inbound all matching", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
setup := newSetup(t, l, myPrefix)
|
||||
tc := buildTestCase(setup, nil, netip.MustParsePrefix("1.2.3.4/24"))
|
||||
tc.Test(t, setup.fw)
|
||||
})
|
||||
t.Run("allow inbound local matching", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
setup := newSetup(t, l, myPrefix)
|
||||
tc := buildTestCase(setup, ErrInvalidLocalIP, netip.MustParsePrefix("1.2.3.4/24"))
|
||||
tc.p.LocalAddr = netip.MustParseAddr("1.2.3.8")
|
||||
tc.Test(t, setup.fw)
|
||||
})
|
||||
t.Run("block inbound remote mismatched", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
setup := newSetup(t, l, myPrefix)
|
||||
tc := buildTestCase(setup, ErrInvalidRemoteIP, netip.MustParsePrefix("1.2.3.4/24"))
|
||||
tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
|
||||
tc.Test(t, setup.fw)
|
||||
})
|
||||
t.Run("Block a vpn peer packet", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
setup := newSetup(t, l, myPrefix)
|
||||
tc := buildTestCase(setup, ErrPeerRejected, netip.MustParsePrefix("2.2.2.2/24"))
|
||||
tc.Test(t, setup.fw)
|
||||
})
|
||||
twoPrefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("1.2.3.4/24"), netip.MustParsePrefix("2.2.2.2/24"),
|
||||
}
|
||||
t.Run("allow inbound one matching", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
setup := newSetup(t, l, myPrefix)
|
||||
tc := buildTestCase(setup, nil, twoPrefixes...)
|
||||
tc.Test(t, setup.fw)
|
||||
})
|
||||
t.Run("block inbound multimismatch", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
setup := newSetup(t, l, myPrefix)
|
||||
tc := buildTestCase(setup, ErrInvalidRemoteIP, twoPrefixes...)
|
||||
tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
|
||||
tc.Test(t, setup.fw)
|
||||
})
|
||||
t.Run("allow inbound 2nd one matching", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
setup2 := newSetup(t, l, netip.MustParsePrefix("2.2.2.1/24"))
|
||||
tc := buildTestCase(setup2, nil, twoPrefixes...)
|
||||
tc.p.RemoteAddr = twoPrefixes[1].Addr()
|
||||
tc.Test(t, setup2.fw)
|
||||
})
|
||||
t.Run("allow inbound unsafe route", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
unsafePrefix := netip.MustParsePrefix("192.168.0.0/24")
|
||||
c := dummyCert{
|
||||
name: "me",
|
||||
networks: []netip.Prefix{myPrefix},
|
||||
unsafeNetworks: []netip.Prefix{unsafePrefix},
|
||||
groups: []string{"default-group"},
|
||||
issuer: "signer-shasum",
|
||||
}
|
||||
unsafeSetup := newSetupFromCert(t, l, c)
|
||||
tc := buildTestCase(unsafeSetup, nil, twoPrefixes...)
|
||||
tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
|
||||
tc.err = ErrNoMatchingRule
|
||||
tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
|
||||
require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, unsafePrefix, "", ""))
|
||||
tc.err = nil
|
||||
tc.Test(t, unsafeSetup.fw) //should pass
|
||||
})
|
||||
}
|
||||
|
||||
type addRuleCall struct {
|
||||
incoming bool
|
||||
proto uint8
|
||||
|
||||
44
go.mod
44
go.mod
@@ -1,38 +1,39 @@
|
||||
module github.com/slackhq/nebula
|
||||
|
||||
go 1.25
|
||||
go 1.23.0
|
||||
|
||||
toolchain go1.24.1
|
||||
|
||||
require (
|
||||
dario.cat/mergo v1.0.2
|
||||
dario.cat/mergo v1.0.1
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
|
||||
github.com/armon/go-radix v1.0.0
|
||||
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
|
||||
github.com/flynn/noise v1.1.0
|
||||
github.com/gaissmai/bart v0.26.0
|
||||
github.com/gaissmai/bart v0.20.4
|
||||
github.com/gogo/protobuf v1.3.2
|
||||
github.com/google/gopacket v1.1.19
|
||||
github.com/kardianos/service v1.2.4
|
||||
github.com/miekg/dns v1.1.68
|
||||
github.com/kardianos/service v1.2.2
|
||||
github.com/miekg/dns v1.1.65
|
||||
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
|
||||
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/prometheus/client_golang v1.22.0
|
||||
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/vishvananda/netlink v1.3.1
|
||||
go.yaml.in/yaml/v3 v3.0.4
|
||||
golang.org/x/crypto v0.44.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/vishvananda/netlink v1.3.0
|
||||
golang.org/x/crypto v0.37.0
|
||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
|
||||
golang.org/x/net v0.46.0
|
||||
golang.org/x/sync v0.18.0
|
||||
golang.org/x/sys v0.38.0
|
||||
golang.org/x/term v0.37.0
|
||||
golang.org/x/net v0.39.0
|
||||
golang.org/x/sync v0.13.0
|
||||
golang.org/x/sys v0.32.0
|
||||
golang.org/x/term v0.31.0
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
google.golang.org/protobuf v1.36.10
|
||||
google.golang.org/protobuf v1.36.6
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
|
||||
)
|
||||
@@ -44,12 +45,11 @@ require (
|
||||
github.com/google/btree v1.1.2 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/prometheus/client_model v0.6.2 // indirect
|
||||
github.com/prometheus/common v0.66.1 // indirect
|
||||
github.com/prometheus/procfs v0.16.1 // indirect
|
||||
github.com/vishvananda/netns v0.0.5 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
||||
golang.org/x/mod v0.24.0 // indirect
|
||||
github.com/prometheus/client_model v0.6.1 // indirect
|
||||
github.com/prometheus/common v0.62.0 // indirect
|
||||
github.com/prometheus/procfs v0.15.1 // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
golang.org/x/mod v0.23.0 // indirect
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
golang.org/x/tools v0.33.0 // indirect
|
||||
golang.org/x/tools v0.30.0 // indirect
|
||||
)
|
||||
|
||||
83
go.sum
83
go.sum
@@ -1,6 +1,6 @@
|
||||
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
|
||||
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
|
||||
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
|
||||
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
|
||||
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
|
||||
@@ -24,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
|
||||
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
|
||||
github.com/gaissmai/bart v0.26.0 h1:xOZ57E9hJLBiQaSyeZa9wgWhGuzfGACgqp4BE77OkO0=
|
||||
github.com/gaissmai/bart v0.26.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
|
||||
github.com/gaissmai/bart v0.20.4 h1:Ik47r1fy3jRVU+1eYzKSW3ho2UgBVTVnUS8O993584U=
|
||||
github.com/gaissmai/bart v0.20.4/go.mod h1:cEed+ge8dalcbpi8wtS9x9m2hn/fNJH5suhdGQOHnYk=
|
||||
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
|
||||
@@ -64,8 +64,8 @@ github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/
|
||||
github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
|
||||
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
|
||||
github.com/kardianos/service v1.2.4 h1:XNlGtZOYNx2u91urOdg/Kfmc+gfmuIo1Dd3rEi2OgBk=
|
||||
github.com/kardianos/service v1.2.4/go.mod h1:E4V9ufUuY82F7Ztlu1eN9VXWIQxg8NoLQlmFe0MtrXc=
|
||||
github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX60=
|
||||
github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
@@ -83,8 +83,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||
github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA=
|
||||
github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
|
||||
github.com/miekg/dns v1.1.65 h1:0+tIPHzUW0GCge7IiK3guGP57VAw7hoPDfApjkMD1Fc=
|
||||
github.com/miekg/dns v1.1.65/go.mod h1:Dzw9769uoKVaLuODMDZz9M6ynFU6Em65csPuoi8G0ck=
|
||||
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
|
||||
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
@@ -106,24 +106,24 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
|
||||
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
|
||||
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
|
||||
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
|
||||
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
|
||||
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
||||
github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q=
|
||||
github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0=
|
||||
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
|
||||
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
||||
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
||||
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
|
||||
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
|
||||
github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
|
||||
github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
|
||||
github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
|
||||
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
|
||||
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
|
||||
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
|
||||
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
||||
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
|
||||
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
|
||||
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
|
||||
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
|
||||
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
|
||||
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
|
||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM=
|
||||
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
@@ -143,35 +143,29 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
|
||||
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
|
||||
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
|
||||
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
|
||||
github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
|
||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
|
||||
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||
golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
|
||||
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
|
||||
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
|
||||
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
|
||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
|
||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
|
||||
golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
|
||||
golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
|
||||
golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
@@ -182,8 +176,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
|
||||
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
|
||||
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
|
||||
golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
|
||||
golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -191,8 +185,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
|
||||
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
@@ -203,17 +197,18 @@ golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
|
||||
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
||||
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
|
||||
golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o=
|
||||
golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
@@ -224,8 +219,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
|
||||
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
||||
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
|
||||
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
|
||||
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
|
||||
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
@@ -244,8 +239,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
|
||||
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
|
||||
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
|
||||
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
|
||||
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
|
||||
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
||||
157
handshake_ix.go
157
handshake_ix.go
@@ -2,6 +2,7 @@ package nebula
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
@@ -22,19 +23,15 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// If we're connecting to a v6 address we must use a v2 cert
|
||||
cs := f.pki.getCertState()
|
||||
v := cs.initiatingVersion
|
||||
if hh.initiatingVersionOverride != cert.VersionPre1 {
|
||||
v = hh.initiatingVersionOverride
|
||||
} else if v < cert.Version2 {
|
||||
// If we're connecting to a v6 address we should encourage use of a V2 cert
|
||||
for _, a := range hh.hostinfo.vpnAddrs {
|
||||
if a.Is6() {
|
||||
v = cert.Version2
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
crt := cs.getCertificate(v)
|
||||
if crt == nil {
|
||||
@@ -51,7 +48,6 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||
WithField("certVersion", v).
|
||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||
return false
|
||||
}
|
||||
|
||||
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
|
||||
@@ -107,7 +103,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||
WithField("certVersion", cs.initiatingVersion).
|
||||
Error("Unable to handshake with host because no certificate is available")
|
||||
return
|
||||
}
|
||||
|
||||
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
|
||||
@@ -148,8 +143,8 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
|
||||
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
|
||||
if err != nil {
|
||||
fp, fperr := rc.Fingerprint()
|
||||
if fperr != nil {
|
||||
fp, err := rc.Fingerprint()
|
||||
if err != nil {
|
||||
fp = "<error generating certificate fingerprint>"
|
||||
}
|
||||
|
||||
@@ -168,19 +163,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
|
||||
if remoteCert.Certificate.Version() != ci.myCert.Version() {
|
||||
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
||||
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
|
||||
if myCertOtherVersion == nil {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithError(err).WithFields(m{
|
||||
"udpAddr": addr,
|
||||
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
||||
"cert": remoteCert,
|
||||
}).Debug("Might be unable to handshake with host due to missing certificate version")
|
||||
rc := cs.getCertificate(remoteCert.Certificate.Version())
|
||||
if rc == nil {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
|
||||
Info("Unable to handshake with host due to missing certificate version")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
|
||||
// Record the certificate we are actually using
|
||||
ci.myCert = myCertOtherVersion
|
||||
}
|
||||
ci.myCert = rc
|
||||
}
|
||||
|
||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||
@@ -191,17 +183,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
return
|
||||
}
|
||||
|
||||
var vpnAddrs []netip.Addr
|
||||
var filteredNetworks []netip.Prefix
|
||||
certName := remoteCert.Certificate.Name()
|
||||
certVersion := remoteCert.Certificate.Version()
|
||||
fingerprint := remoteCert.Fingerprint
|
||||
issuer := remoteCert.Certificate.Issuer()
|
||||
vpnNetworks := remoteCert.Certificate.Networks()
|
||||
|
||||
anyVpnAddrsInCommon := false
|
||||
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
||||
for i, network := range vpnNetworks {
|
||||
if f.myVpnAddrsTable.Contains(network.Addr()) {
|
||||
f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
|
||||
for _, network := range remoteCert.Certificate.Networks() {
|
||||
vpnAddr := network.Addr()
|
||||
if f.myVpnAddrsTable.Contains(vpnAddr) {
|
||||
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
@@ -209,10 +201,24 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
||||
return
|
||||
}
|
||||
vpnAddrs[i] = network.Addr()
|
||||
if f.myVpnNetworksTable.Contains(network.Addr()) {
|
||||
anyVpnAddrsInCommon = true
|
||||
|
||||
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
||||
if !f.myVpnNetworksTable.Contains(vpnAddr) {
|
||||
continue
|
||||
}
|
||||
|
||||
filteredNetworks = append(filteredNetworks, network)
|
||||
vpnAddrs = append(vpnAddrs, vpnAddr)
|
||||
}
|
||||
|
||||
if len(vpnAddrs) == 0 {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
||||
return
|
||||
}
|
||||
|
||||
if addr.IsValid() {
|
||||
@@ -243,36 +249,32 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
HandshakePacket: make(map[uint8][]byte, 0),
|
||||
lastHandshakeTime: hs.Details.Time,
|
||||
relayState: RelayState{
|
||||
relays: nil,
|
||||
relays: map[netip.Addr]struct{}{},
|
||||
relayForByAddr: map[netip.Addr]*Relay{},
|
||||
relayForByIdx: map[uint32]*Relay{},
|
||||
},
|
||||
}
|
||||
|
||||
msgRxL := f.l.WithFields(m{
|
||||
"vpnAddrs": vpnAddrs,
|
||||
"udpAddr": addr,
|
||||
"certName": certName,
|
||||
"certVersion": certVersion,
|
||||
"fingerprint": fingerprint,
|
||||
"issuer": issuer,
|
||||
"initiatorIndex": hs.Details.InitiatorIndex,
|
||||
"responderIndex": hs.Details.ResponderIndex,
|
||||
"remoteIndex": h.RemoteIndex,
|
||||
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
||||
})
|
||||
|
||||
if anyVpnAddrsInCommon {
|
||||
msgRxL.Info("Handshake message received")
|
||||
} else {
|
||||
//todo warn if not lighthouse or relay?
|
||||
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
|
||||
}
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Info("Handshake message received")
|
||||
|
||||
hs.Details.ResponderIndex = myIndex
|
||||
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
|
||||
if hs.Details.Cert == nil {
|
||||
msgRxL.WithField("myCertVersion", ci.myCert.Version()).
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
WithField("certVersion", ci.myCert.Version()).
|
||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||
return
|
||||
}
|
||||
@@ -330,7 +332,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
|
||||
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
|
||||
hostinfo.SetRemote(addr)
|
||||
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
|
||||
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
|
||||
|
||||
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
|
||||
if err != nil {
|
||||
@@ -455,9 +457,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
Info("Handshake message sent")
|
||||
}
|
||||
|
||||
f.connectionManager.AddTrafficWatch(hostinfo)
|
||||
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
|
||||
|
||||
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
|
||||
hostinfo.remotes.ResetBlockedRemotes()
|
||||
|
||||
return
|
||||
}
|
||||
@@ -571,22 +573,31 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||
}
|
||||
|
||||
correctHostResponded := false
|
||||
anyVpnAddrsInCommon := false
|
||||
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
||||
for i, network := range vpnNetworks {
|
||||
vpnAddrs[i] = network.Addr()
|
||||
if f.myVpnNetworksTable.Contains(network.Addr()) {
|
||||
anyVpnAddrsInCommon = true
|
||||
var vpnAddrs []netip.Addr
|
||||
var filteredNetworks []netip.Prefix
|
||||
for _, network := range vpnNetworks {
|
||||
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
||||
vpnAddr := network.Addr()
|
||||
if !f.myVpnNetworksTable.Contains(vpnAddr) {
|
||||
continue
|
||||
}
|
||||
if hostinfo.vpnAddrs[0] == network.Addr() {
|
||||
// todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not?
|
||||
correctHostResponded = true
|
||||
|
||||
filteredNetworks = append(filteredNetworks, network)
|
||||
vpnAddrs = append(vpnAddrs, vpnAddr)
|
||||
}
|
||||
|
||||
if len(vpnAddrs) == 0 {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
||||
return true
|
||||
}
|
||||
|
||||
// Ensure the right host responded
|
||||
if !correctHostResponded {
|
||||
if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
|
||||
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
||||
WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
@@ -598,7 +609,6 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
f.handshakeManager.DeleteHostInfo(hostinfo)
|
||||
|
||||
// Create a new hostinfo/handshake for the intended vpn ip
|
||||
//TODO is hostinfo.vpnAddrs[0] always the address to use?
|
||||
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
|
||||
// Block the current used address
|
||||
newHH.hostinfo.remotes = hostinfo.remotes
|
||||
@@ -625,7 +635,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
ci.window.Update(f.l, 2)
|
||||
|
||||
duration := time.Since(hh.startTime).Nanoseconds()
|
||||
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
@@ -633,21 +643,16 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
WithField("durationNs", duration).
|
||||
WithField("sentCachedPackets", len(hh.packetStore))
|
||||
if anyVpnAddrsInCommon {
|
||||
msgRxL.Info("Handshake message received")
|
||||
} else {
|
||||
//todo warn if not lighthouse or relay?
|
||||
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
|
||||
}
|
||||
WithField("sentCachedPackets", len(hh.packetStore)).
|
||||
Info("Handshake message received")
|
||||
|
||||
// Build up the radix for the firewall if we have subnets in the cert
|
||||
hostinfo.vpnAddrs = vpnAddrs
|
||||
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
|
||||
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
|
||||
|
||||
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
|
||||
f.handshakeManager.Complete(hostinfo, f)
|
||||
f.connectionManager.AddTrafficWatch(hostinfo)
|
||||
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
|
||||
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))
|
||||
@@ -662,7 +667,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
|
||||
}
|
||||
|
||||
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
|
||||
hostinfo.remotes.ResetBlockedRemotes()
|
||||
f.metricHandshakes.Update(duration)
|
||||
|
||||
return false
|
||||
|
||||
@@ -70,7 +70,6 @@ type HandshakeHostInfo struct {
|
||||
|
||||
startTime time.Time // Time that we first started trying with this handshake
|
||||
ready bool // Is the handshake ready
|
||||
initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake?
|
||||
counter int64 // How many attempts have we made so far
|
||||
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
|
||||
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
||||
@@ -269,12 +268,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
|
||||
// Send a RelayRequest to all known Relay IP's
|
||||
for _, relay := range hostinfo.remotes.relays {
|
||||
// Don't relay through the host I'm trying to connect to
|
||||
// Don't relay to myself
|
||||
if relay == vpnIp {
|
||||
continue
|
||||
}
|
||||
|
||||
// Don't relay to myself
|
||||
// Don't relay through the host I'm trying to connect to
|
||||
if hm.f.myVpnAddrsTable.Contains(relay) {
|
||||
continue
|
||||
}
|
||||
@@ -451,7 +450,7 @@ func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*Han
|
||||
vpnAddrs: []netip.Addr{vpnAddr},
|
||||
HandshakePacket: make(map[uint8][]byte, 0),
|
||||
relayState: RelayState{
|
||||
relays: nil,
|
||||
relays: map[netip.Addr]struct{}{},
|
||||
relayForByAddr: map[netip.Addr]*Relay{},
|
||||
relayForByIdx: map[uint32]*Relay{},
|
||||
},
|
||||
|
||||
77
hostmap.go
77
hostmap.go
@@ -4,7 +4,6 @@ import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -17,10 +16,12 @@ import (
|
||||
"github.com/slackhq/nebula/header"
|
||||
)
|
||||
|
||||
// const ProbeLen = 100
|
||||
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
|
||||
const defaultReQueryEvery = 5000 // Count of packets sent before re-querying a hostinfo to the lighthouse
|
||||
const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery
|
||||
const MaxRemotes = 10
|
||||
const maxRecvError = 4
|
||||
|
||||
// MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
|
||||
// 5 allows for an initial handshake and each host pair re-handshaking twice
|
||||
@@ -67,7 +68,7 @@ type HostMap struct {
|
||||
type RelayState struct {
|
||||
sync.RWMutex
|
||||
|
||||
relays []netip.Addr // Ordered set of VpnAddrs of Hosts to use as relays to access this peer
|
||||
relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer
|
||||
// For data race avoidance, the contents of a *Relay are treated immutably. To update a *Relay, copy the existing data,
|
||||
// modify what needs to be updated, and store the new modified copy in the relayForByIp and relayForByIdx maps (with
|
||||
// the RelayState Lock held)
|
||||
@@ -78,12 +79,7 @@ type RelayState struct {
|
||||
func (rs *RelayState) DeleteRelay(ip netip.Addr) {
|
||||
rs.Lock()
|
||||
defer rs.Unlock()
|
||||
for idx, val := range rs.relays {
|
||||
if val == ip {
|
||||
rs.relays = append(rs.relays[:idx], rs.relays[idx+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
delete(rs.relays, ip)
|
||||
}
|
||||
|
||||
func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) {
|
||||
@@ -128,16 +124,16 @@ func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) {
|
||||
func (rs *RelayState) InsertRelayTo(ip netip.Addr) {
|
||||
rs.Lock()
|
||||
defer rs.Unlock()
|
||||
if !slices.Contains(rs.relays, ip) {
|
||||
rs.relays = append(rs.relays, ip)
|
||||
}
|
||||
rs.relays[ip] = struct{}{}
|
||||
}
|
||||
|
||||
func (rs *RelayState) CopyRelayIps() []netip.Addr {
|
||||
ret := make([]netip.Addr, len(rs.relays))
|
||||
rs.RLock()
|
||||
defer rs.RUnlock()
|
||||
copy(ret, rs.relays)
|
||||
ret := make([]netip.Addr, 0, len(rs.relays))
|
||||
for ip := range rs.relays {
|
||||
ret = append(ret, ip)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
@@ -212,18 +208,6 @@ func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
|
||||
rs.relayForByIdx[idx] = r
|
||||
}
|
||||
|
||||
type NetworkType uint8
|
||||
|
||||
const (
|
||||
NetworkTypeUnknown NetworkType = iota
|
||||
// NetworkTypeVPN is a network that overlaps one or more of the vpnNetworks in our certificate
|
||||
NetworkTypeVPN
|
||||
// NetworkTypeVPNPeer is a network that does not overlap one of our networks
|
||||
NetworkTypeVPNPeer
|
||||
// NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
|
||||
NetworkTypeUnsafe
|
||||
)
|
||||
|
||||
type HostInfo struct {
|
||||
remote netip.AddrPort
|
||||
remotes *RemoteList
|
||||
@@ -236,9 +220,10 @@ type HostInfo struct {
|
||||
// The host may have other vpn addresses that are outside our
|
||||
// vpn networks but were removed because they are not usable
|
||||
vpnAddrs []netip.Addr
|
||||
recvError atomic.Uint32
|
||||
|
||||
// networks is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host.
|
||||
networks *bart.Table[NetworkType]
|
||||
// networks are both all vpn and unsafe networks assigned to this host
|
||||
networks *bart.Lite
|
||||
relayState RelayState
|
||||
|
||||
// HandshakePacket records the packets used to create this hostinfo
|
||||
@@ -265,14 +250,6 @@ type HostInfo struct {
|
||||
// Used to track other hostinfos for this vpn ip since only 1 can be primary
|
||||
// Synchronised via hostmap lock and not the hostinfo lock.
|
||||
next, prev *HostInfo
|
||||
|
||||
//TODO: in, out, and others might benefit from being an atomic.Int32. We could collapse connectionManager pendingDeletion, relayUsed, and in/out into this 1 thing
|
||||
in, out, pendingDeletion atomic.Bool
|
||||
|
||||
// lastUsed tracks the last time ConnectionManager checked the tunnel and it was in use.
|
||||
// This value will be behind against actual tunnel utilization in the hot path.
|
||||
// This should only be used by the ConnectionManagers ticker routine.
|
||||
lastUsed time.Time
|
||||
}
|
||||
|
||||
type ViaSender struct {
|
||||
@@ -742,26 +719,26 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
|
||||
return false
|
||||
}
|
||||
|
||||
// buildNetworks fills in the networks field of HostInfo. It accepts a cert.Certificate so you never ever mix the network types up.
|
||||
func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certificate) {
|
||||
if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 {
|
||||
if myVpnNetworksTable.Contains(c.Networks()[0].Addr()) {
|
||||
return // Simple case, no BART needed
|
||||
func (i *HostInfo) RecvErrorExceeded() bool {
|
||||
if i.recvError.Add(1) >= maxRecvError {
|
||||
return true
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
i.networks = new(bart.Table[NetworkType])
|
||||
for _, network := range c.Networks() {
|
||||
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
||||
if myVpnNetworksTable.Contains(network.Addr()) {
|
||||
i.networks.Insert(nprefix, NetworkTypeVPN)
|
||||
} else {
|
||||
i.networks.Insert(nprefix, NetworkTypeVPNPeer)
|
||||
}
|
||||
func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
|
||||
if len(networks) == 1 && len(unsafeNetworks) == 0 {
|
||||
// Simple case, no CIDRTree needed
|
||||
return
|
||||
}
|
||||
|
||||
for _, network := range c.UnsafeNetworks() {
|
||||
i.networks.Insert(network, NetworkTypeUnsafe)
|
||||
i.networks = new(bart.Lite)
|
||||
for _, network := range networks {
|
||||
i.networks.Insert(network)
|
||||
}
|
||||
|
||||
for _, network := range unsafeNetworks {
|
||||
i.networks.Insert(network)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHostMap_MakePrimary(t *testing.T) {
|
||||
@@ -216,31 +215,3 @@ func TestHostMap_reload(t *testing.T) {
|
||||
c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
|
||||
assert.Equal(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
|
||||
}
|
||||
|
||||
func TestHostMap_RelayState(t *testing.T) {
|
||||
h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
|
||||
a1 := netip.MustParseAddr("::1")
|
||||
a2 := netip.MustParseAddr("2001::1")
|
||||
|
||||
h1.relayState.InsertRelayTo(a1)
|
||||
assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
|
||||
h1.relayState.InsertRelayTo(a2)
|
||||
assert.Equal(t, []netip.Addr{a1, a2}, h1.relayState.relays)
|
||||
// Ensure that the first relay added is the first one returned in the copy
|
||||
currentRelays := h1.relayState.CopyRelayIps()
|
||||
require.Len(t, currentRelays, 2)
|
||||
assert.Equal(t, a1, currentRelays[0])
|
||||
|
||||
// Deleting the last one in the list works ok
|
||||
h1.relayState.DeleteRelay(a2)
|
||||
assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
|
||||
|
||||
// Deleting an element not in the list works ok
|
||||
h1.relayState.DeleteRelay(a2)
|
||||
assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
|
||||
|
||||
// Deleting the only element in the list works ok
|
||||
h1.relayState.DeleteRelay(a1)
|
||||
assert.Equal(t, []netip.Addr{}, h1.relayState.relays)
|
||||
|
||||
}
|
||||
|
||||
15
inside.go
15
inside.go
@@ -120,10 +120,9 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
|
||||
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
|
||||
}
|
||||
|
||||
// Handshake will attempt to initiate a tunnel with the provided vpn address. This is a no-op if the tunnel is already established or being established
|
||||
// it does not check if it is within our vpn networks!
|
||||
// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
|
||||
func (f *Interface) Handshake(vpnAddr netip.Addr) {
|
||||
f.handshakeManager.GetOrHandshake(vpnAddr, nil)
|
||||
f.getOrHandshakeNoRouting(vpnAddr, nil)
|
||||
}
|
||||
|
||||
// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
|
||||
@@ -139,6 +138,7 @@ func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback fu
|
||||
// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary.
|
||||
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel.
|
||||
func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
|
||||
|
||||
destinationAddr := fwPacket.RemoteAddr
|
||||
|
||||
hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback)
|
||||
@@ -231,10 +231,9 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
|
||||
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
|
||||
}
|
||||
|
||||
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr.
|
||||
// This function ignores myVpnNetworksTable, and will always attempt to treat the address as a vpnAddr
|
||||
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
|
||||
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
|
||||
hostInfo, ready := f.handshakeManager.GetOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
|
||||
hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) {
|
||||
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
|
||||
})
|
||||
|
||||
@@ -289,7 +288,7 @@ func (f *Interface) SendVia(via *HostInfo,
|
||||
c := via.ConnectionState.messageCounter.Add(1)
|
||||
|
||||
out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c)
|
||||
f.connectionManager.Out(via)
|
||||
f.connectionManager.Out(via.localIndexId)
|
||||
|
||||
// Authenticate the header and payload, but do not encrypt for this message type.
|
||||
// The payload consists of the inner, unencrypted Nebula header, as well as the end-to-end encrypted payload.
|
||||
@@ -357,7 +356,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
||||
|
||||
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
|
||||
out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c)
|
||||
f.connectionManager.Out(hostinfo)
|
||||
f.connectionManager.Out(hostinfo.localIndexId)
|
||||
|
||||
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
|
||||
// all our addrs and enable a faster roaming.
|
||||
|
||||
11
interface.go
11
interface.go
@@ -28,12 +28,12 @@ type InterfaceConfig struct {
|
||||
Outside udp.Conn
|
||||
Inside overlay.Device
|
||||
pki *PKI
|
||||
Cipher string
|
||||
Firewall *Firewall
|
||||
ServeDns bool
|
||||
HandshakeManager *HandshakeManager
|
||||
lightHouse *LightHouse
|
||||
connectionManager *connectionManager
|
||||
checkInterval time.Duration
|
||||
pendingDeletionInterval time.Duration
|
||||
DropLocalBroadcast bool
|
||||
DropMulticast bool
|
||||
routines int
|
||||
@@ -157,9 +157,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
||||
if c.Firewall == nil {
|
||||
return nil, errors.New("no firewall rules")
|
||||
}
|
||||
if c.connectionManager == nil {
|
||||
return nil, errors.New("no connection manager")
|
||||
}
|
||||
|
||||
cs := c.pki.getCertState()
|
||||
ifce := &Interface{
|
||||
@@ -184,7 +181,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
||||
myVpnAddrsTable: cs.myVpnAddrsTable,
|
||||
myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable,
|
||||
relayManager: c.relayManager,
|
||||
connectionManager: c.connectionManager,
|
||||
|
||||
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
||||
|
||||
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
||||
@@ -201,7 +198,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
||||
ifce.reQueryEvery.Store(c.reQueryEvery)
|
||||
ifce.reQueryWait.Store(int64(c.reQueryWait))
|
||||
|
||||
ifce.connectionManager.intf = ifce
|
||||
ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy)
|
||||
|
||||
return ifce, nil
|
||||
}
|
||||
|
||||
240
lighthouse.go
240
lighthouse.go
@@ -24,7 +24,6 @@ import (
|
||||
)
|
||||
|
||||
var ErrHostNotKnown = errors.New("host not known")
|
||||
var ErrBadDetailsVpnAddr = errors.New("invalid packet, malformed detailsVpnAddr")
|
||||
|
||||
type LightHouse struct {
|
||||
//TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time
|
||||
@@ -57,7 +56,7 @@ type LightHouse struct {
|
||||
// staticList exists to avoid having a bool in each addrMap entry
|
||||
// since static should be rare
|
||||
staticList atomic.Pointer[map[netip.Addr]struct{}]
|
||||
lighthouses atomic.Pointer[[]netip.Addr]
|
||||
lighthouses atomic.Pointer[map[netip.Addr]struct{}]
|
||||
|
||||
interval atomic.Int64
|
||||
updateCancel context.CancelFunc
|
||||
@@ -108,7 +107,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
|
||||
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
|
||||
l: l,
|
||||
}
|
||||
lighthouses := make([]netip.Addr, 0)
|
||||
lighthouses := make(map[netip.Addr]struct{})
|
||||
h.lighthouses.Store(&lighthouses)
|
||||
staticList := make(map[netip.Addr]struct{})
|
||||
h.staticList.Store(&staticList)
|
||||
@@ -144,7 +143,7 @@ func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} {
|
||||
return *lh.staticList.Load()
|
||||
}
|
||||
|
||||
func (lh *LightHouse) GetLighthouses() []netip.Addr {
|
||||
func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} {
|
||||
return *lh.lighthouses.Load()
|
||||
}
|
||||
|
||||
@@ -307,12 +306,13 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
||||
}
|
||||
|
||||
if initial || c.HasChanged("lighthouse.hosts") {
|
||||
lhList, err := lh.parseLighthouses(c)
|
||||
lhMap := make(map[netip.Addr]struct{})
|
||||
err := lh.parseLighthouses(c, lhMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
lh.lighthouses.Store(&lhList)
|
||||
lh.lighthouses.Store(&lhMap)
|
||||
if !initial {
|
||||
//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
|
||||
lh.l.Info("lighthouse.hosts has changed")
|
||||
@@ -346,38 +346,36 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
|
||||
func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error {
|
||||
lhs := c.GetStringSlice("lighthouse.hosts", []string{})
|
||||
if lh.amLighthouse && len(lhs) != 0 {
|
||||
lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
|
||||
}
|
||||
out := make([]netip.Addr, len(lhs))
|
||||
|
||||
for i, host := range lhs {
|
||||
addr, err := netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
|
||||
return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
|
||||
}
|
||||
|
||||
if !lh.myVpnNetworksTable.Contains(addr) {
|
||||
lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}).
|
||||
Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not")
|
||||
return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
|
||||
}
|
||||
out[i] = addr
|
||||
lhMap[addr] = struct{}{}
|
||||
}
|
||||
|
||||
if !lh.amLighthouse && len(out) == 0 {
|
||||
if !lh.amLighthouse && len(lhMap) == 0 {
|
||||
lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries")
|
||||
}
|
||||
|
||||
staticList := lh.GetStaticHostList()
|
||||
for i := range out {
|
||||
if _, ok := staticList[out[i]]; !ok {
|
||||
return nil, fmt.Errorf("lighthouse %s does not have a static_host_map entry", out[i])
|
||||
for lhAddr, _ := range lhMap {
|
||||
if _, ok := staticList[lhAddr]; !ok {
|
||||
return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr)
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStaticMapCadence(c *config.C) (time.Duration, error) {
|
||||
@@ -432,8 +430,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
|
||||
}
|
||||
|
||||
if !lh.myVpnNetworksTable.Contains(vpnAddr) {
|
||||
lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}).
|
||||
Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work")
|
||||
return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil)
|
||||
}
|
||||
|
||||
vals, ok := v.([]any)
|
||||
@@ -489,7 +486,7 @@ func (lh *LightHouse) QueryCache(vpnAddrs []netip.Addr) *RemoteList {
|
||||
lh.Lock()
|
||||
defer lh.Unlock()
|
||||
// Add an entry if we don't already have one
|
||||
return lh.unlockedGetRemoteList(vpnAddrs) //todo CERT-V2 this contains addrmap lookups we could potentially skip
|
||||
return lh.unlockedGetRemoteList(vpnAddrs)
|
||||
}
|
||||
|
||||
// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
|
||||
@@ -522,15 +519,11 @@ func (lh *LightHouse) queryAndPrepMessage(vpnAddr netip.Addr, f func(*cache) (in
|
||||
}
|
||||
|
||||
func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) {
|
||||
// First we check the static host map. If any of the VpnAddrs to be deleted are present, do nothing.
|
||||
staticList := lh.GetStaticHostList()
|
||||
for _, addr := range allVpnAddrs {
|
||||
if _, ok := staticList[addr]; ok {
|
||||
// First we check the static mapping
|
||||
// and do nothing if it is there
|
||||
if _, ok := lh.GetStaticHostList()[allVpnAddrs[0]]; ok {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// None of the VpnAddrs were present. Now we can do the deletes.
|
||||
lh.Lock()
|
||||
rm, ok := lh.addrMap[allVpnAddrs[0]]
|
||||
if ok {
|
||||
@@ -572,7 +565,7 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
|
||||
am.unlockedSetHostnamesResults(hr)
|
||||
|
||||
for _, addrPort := range hr.GetAddrs() {
|
||||
if !lh.shouldAdd([]netip.Addr{vpnAddr}, addrPort.Addr()) {
|
||||
if !lh.shouldAdd(vpnAddr, addrPort.Addr()) {
|
||||
continue
|
||||
}
|
||||
switch {
|
||||
@@ -634,30 +627,23 @@ func (lh *LightHouse) addCalculatedRemotes(vpnAddr netip.Addr) bool {
|
||||
return len(calculatedV4) > 0 || len(calculatedV6) > 0
|
||||
}
|
||||
|
||||
// unlockedGetRemoteList assumes you have the lh lock
|
||||
// unlockedGetRemoteList
|
||||
// assumes you have the lh lock
|
||||
func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
|
||||
// before we go and make a new remotelist, we need to make sure we don't have one for any of this set of vpnaddrs yet
|
||||
for i, addr := range allAddrs {
|
||||
am, ok := lh.addrMap[addr]
|
||||
if ok {
|
||||
if i != 0 {
|
||||
lh.addrMap[allAddrs[0]] = am
|
||||
}
|
||||
return am
|
||||
}
|
||||
}
|
||||
|
||||
am := NewRemoteList(allAddrs, lh.shouldAdd)
|
||||
am, ok := lh.addrMap[allAddrs[0]]
|
||||
if !ok {
|
||||
am = NewRemoteList(allAddrs, func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) })
|
||||
for _, addr := range allAddrs {
|
||||
lh.addrMap[addr] = am
|
||||
}
|
||||
}
|
||||
return am
|
||||
}
|
||||
|
||||
func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
|
||||
allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
|
||||
func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool {
|
||||
allow := lh.GetRemoteAllowList().Allow(vpnAddr, to)
|
||||
if lh.l.Level >= logrus.TraceLevel {
|
||||
lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow).
|
||||
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", to).WithField("allow", allow).
|
||||
Trace("remoteAllowList.Allow")
|
||||
}
|
||||
if !allow {
|
||||
@@ -712,24 +698,21 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
|
||||
}
|
||||
|
||||
func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
|
||||
l := lh.GetLighthouses()
|
||||
for i := range l {
|
||||
if l[i] == vpnAddr {
|
||||
if _, ok := lh.GetLighthouses()[vpnAddr]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool {
|
||||
// TODO: CERT-V2 IsLighthouseAddr should be sufficient, we just need to update the vpnAddrs for lighthouses after a handshake
|
||||
// so that we know all the lighthouse vpnAddrs, not just the ones we were configured to talk to initially
|
||||
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddr []netip.Addr) bool {
|
||||
l := lh.GetLighthouses()
|
||||
for i := range vpnAddrs {
|
||||
for j := range l {
|
||||
if l[j] == vpnAddrs[i] {
|
||||
for _, a := range vpnAddr {
|
||||
if _, ok := l[a]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -769,7 +752,7 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
|
||||
queried := 0
|
||||
lighthouses := lh.GetLighthouses()
|
||||
|
||||
for _, lhVpnAddr := range lighthouses {
|
||||
for lhVpnAddr := range lighthouses {
|
||||
hi := lh.ifce.GetHostInfo(lhVpnAddr)
|
||||
if hi != nil {
|
||||
v = hi.ConnectionState.myCert.Version()
|
||||
@@ -887,7 +870,7 @@ func (lh *LightHouse) SendUpdate() {
|
||||
updated := 0
|
||||
lighthouses := lh.GetLighthouses()
|
||||
|
||||
for _, lhVpnAddr := range lighthouses {
|
||||
for lhVpnAddr := range lighthouses {
|
||||
var v cert.Version
|
||||
hi := lh.ifce.GetHostInfo(lhVpnAddr)
|
||||
if hi != nil {
|
||||
@@ -945,6 +928,7 @@ func (lh *LightHouse) SendUpdate() {
|
||||
V4AddrPorts: v4,
|
||||
V6AddrPorts: v6,
|
||||
RelayVpnAddrs: relays,
|
||||
VpnAddr: netAddrToProtoAddr(lh.myVpnNetworks[0].Addr()),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1064,19 +1048,19 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
||||
return
|
||||
}
|
||||
|
||||
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
|
||||
if err != nil {
|
||||
useVersion := cert.Version1
|
||||
var queryVpnAddr netip.Addr
|
||||
if n.Details.OldVpnAddr != 0 {
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
||||
queryVpnAddr = netip.AddrFrom4(b)
|
||||
useVersion = 1
|
||||
} else if n.Details.VpnAddr != nil {
|
||||
queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
||||
useVersion = 2
|
||||
} else {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
|
||||
Debugln("Dropping malformed HostQuery")
|
||||
}
|
||||
return
|
||||
}
|
||||
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
|
||||
// this case really shouldn't be possible to represent, but reject it anyway.
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
|
||||
Debugln("invalid vpn addr for v1 handleHostQuery")
|
||||
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).Debugln("Dropping malformed HostQuery")
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -1085,6 +1069,9 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
||||
n = lhh.resetMeta()
|
||||
n.Type = NebulaMeta_HostQueryReply
|
||||
if useVersion == cert.Version1 {
|
||||
if !queryVpnAddr.Is4() {
|
||||
return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery")
|
||||
}
|
||||
b := queryVpnAddr.As4()
|
||||
n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])
|
||||
} else {
|
||||
@@ -1129,9 +1116,8 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
|
||||
if ok {
|
||||
whereToPunch = newDest
|
||||
} else {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
|
||||
}
|
||||
//TODO: CERT-V2 this means the destination will have no addresses in common with the punch-ee
|
||||
//choosing to do nothing for now, but maybe we return an error?
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1190,17 +1176,19 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
|
||||
if !r.Is4() {
|
||||
continue
|
||||
}
|
||||
|
||||
b = r.As4()
|
||||
n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:]))
|
||||
}
|
||||
|
||||
} else if v == cert.Version2 {
|
||||
for _, r := range c.relay.relay {
|
||||
n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
|
||||
}
|
||||
|
||||
} else {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("version", v).Debug("unsupported protocol version")
|
||||
}
|
||||
//TODO: CERT-V2 don't panic
|
||||
panic("unsupported version")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1210,16 +1198,18 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
|
||||
return
|
||||
}
|
||||
|
||||
certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
|
||||
if err != nil {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply")
|
||||
}
|
||||
return
|
||||
lhh.lh.Lock()
|
||||
|
||||
var certVpnAddr netip.Addr
|
||||
if n.Details.OldVpnAddr != 0 {
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
||||
certVpnAddr = netip.AddrFrom4(b)
|
||||
} else if n.Details.VpnAddr != nil {
|
||||
certVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
||||
}
|
||||
relays := n.Details.GetRelays()
|
||||
|
||||
lhh.lh.Lock()
|
||||
am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnAddr})
|
||||
am.Lock()
|
||||
lhh.lh.Unlock()
|
||||
@@ -1244,24 +1234,27 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
||||
return
|
||||
}
|
||||
|
||||
// not using GetVpnAddrAndVersion because we don't want to error on a blank detailsVpnAddr
|
||||
var detailsVpnAddr netip.Addr
|
||||
var useVersion cert.Version
|
||||
if n.Details.OldVpnAddr != 0 { //v1 always sets this field
|
||||
useVersion := cert.Version1
|
||||
if n.Details.OldVpnAddr != 0 {
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
||||
detailsVpnAddr = netip.AddrFrom4(b)
|
||||
useVersion = cert.Version1
|
||||
} else if n.Details.VpnAddr != nil { //this field is "optional" in v2, but if it's set, we should enforce it
|
||||
} else if n.Details.VpnAddr != nil {
|
||||
detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
||||
useVersion = cert.Version2
|
||||
} else {
|
||||
detailsVpnAddr = netip.Addr{}
|
||||
useVersion = cert.Version2
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("details", n.Details).Debugf("dropping invalid HostUpdateNotification")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
//Simple check that the host sent this not someone else, if detailsVpnAddr is filled
|
||||
if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
|
||||
//TODO: CERT-V2 hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4?
|
||||
//TODO: CERT-V2 why do we care about the vpnAddr in the packet? We know where it came from, right?
|
||||
//Simple check that the host sent this not someone else
|
||||
if !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
|
||||
}
|
||||
@@ -1275,24 +1268,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
||||
am.Lock()
|
||||
lhh.lh.Unlock()
|
||||
|
||||
am.unlockedSetV4(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
|
||||
am.unlockedSetV6(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
|
||||
am.unlockedSetV4(fromVpnAddrs[0], detailsVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
|
||||
am.unlockedSetV6(fromVpnAddrs[0], detailsVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
|
||||
am.unlockedSetRelay(fromVpnAddrs[0], relays)
|
||||
am.Unlock()
|
||||
|
||||
n = lhh.resetMeta()
|
||||
n.Type = NebulaMeta_HostUpdateNotificationAck
|
||||
switch useVersion {
|
||||
case cert.Version1:
|
||||
|
||||
if useVersion == cert.Version1 {
|
||||
if !fromVpnAddrs[0].Is4() {
|
||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
|
||||
return
|
||||
}
|
||||
vpnAddrB := fromVpnAddrs[0].As4()
|
||||
n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:])
|
||||
case cert.Version2:
|
||||
// do nothing, we want to send a blank message
|
||||
default:
|
||||
} else if useVersion == cert.Version2 {
|
||||
n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0])
|
||||
} else {
|
||||
lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
|
||||
return
|
||||
}
|
||||
@@ -1310,20 +1303,13 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
||||
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
|
||||
//It's possible the lighthouse is communicating with us using a non primary vpn addr,
|
||||
//which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs.
|
||||
//maybe one day we'll have a better idea, if it matters.
|
||||
if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) {
|
||||
return
|
||||
}
|
||||
|
||||
detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
|
||||
if err != nil {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
empty := []byte{0}
|
||||
punch := func(vpnPeer netip.AddrPort, logVpnAddr netip.Addr) {
|
||||
punch := func(vpnPeer netip.AddrPort) {
|
||||
if !vpnPeer.IsValid() {
|
||||
return
|
||||
}
|
||||
@@ -1335,38 +1321,48 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
||||
}()
|
||||
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
var logVpnAddr netip.Addr
|
||||
if n.Details.OldVpnAddr != 0 {
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
||||
logVpnAddr = netip.AddrFrom4(b)
|
||||
} else if n.Details.VpnAddr != nil {
|
||||
logVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
||||
}
|
||||
lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr)
|
||||
}
|
||||
}
|
||||
|
||||
remoteAllowList := lhh.lh.GetRemoteAllowList()
|
||||
for _, a := range n.Details.V4AddrPorts {
|
||||
b := protoV4AddrPortToNetAddrPort(a)
|
||||
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
|
||||
punch(b, detailsVpnAddr)
|
||||
}
|
||||
punch(protoV4AddrPortToNetAddrPort(a))
|
||||
}
|
||||
|
||||
for _, a := range n.Details.V6AddrPorts {
|
||||
b := protoV6AddrPortToNetAddrPort(a)
|
||||
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
|
||||
punch(b, detailsVpnAddr)
|
||||
}
|
||||
punch(protoV6AddrPortToNetAddrPort(a))
|
||||
}
|
||||
|
||||
// This sends a nebula test packet to the host trying to contact us. In the case
|
||||
// of a double nat or other difficult scenario, this may help establish
|
||||
// a tunnel.
|
||||
if lhh.lh.punchy.GetRespond() {
|
||||
var queryVpnAddr netip.Addr
|
||||
if n.Details.OldVpnAddr != 0 {
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
||||
queryVpnAddr = netip.AddrFrom4(b)
|
||||
} else if n.Details.VpnAddr != nil {
|
||||
queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
||||
}
|
||||
|
||||
go func() {
|
||||
time.Sleep(lhh.lh.punchy.GetRespondDelay())
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr)
|
||||
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", queryVpnAddr)
|
||||
}
|
||||
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
|
||||
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
|
||||
// managed by a channel.
|
||||
w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||
w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||
}()
|
||||
}
|
||||
}
|
||||
@@ -1445,17 +1441,3 @@ func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr,
|
||||
}
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
|
||||
func (d *NebulaMetaDetails) GetVpnAddrAndVersion() (netip.Addr, cert.Version, error) {
|
||||
if d.OldVpnAddr != 0 {
|
||||
b := [4]byte{}
|
||||
binary.BigEndian.PutUint32(b[:], d.OldVpnAddr)
|
||||
detailsVpnAddr := netip.AddrFrom4(b)
|
||||
return detailsVpnAddr, cert.Version1, nil
|
||||
} else if d.VpnAddr != nil {
|
||||
detailsVpnAddr := protoAddrToNetAddr(d.VpnAddr)
|
||||
return detailsVpnAddr, cert.Version2, nil
|
||||
} else {
|
||||
return netip.Addr{}, cert.Version1, ErrBadDetailsVpnAddr
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.yaml.in/yaml/v3"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestOldIPv4Only(t *testing.T) {
|
||||
@@ -493,123 +493,3 @@ func Test_findNetworkUnion(t *testing.T) {
|
||||
out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestLighthouse_Dont_Delete_Static_Hosts(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
|
||||
myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
|
||||
|
||||
testSameHostNotStatic := netip.MustParseAddr("10.128.0.41")
|
||||
testStaticHost := netip.MustParseAddr("10.128.0.42")
|
||||
//myVpnIp := netip.MustParseAddr("10.128.0.2")
|
||||
|
||||
c := config.NewC(l)
|
||||
lh1 := "10.128.0.2"
|
||||
c.Settings["lighthouse"] = map[string]any{
|
||||
"hosts": []any{lh1},
|
||||
"interval": "1s",
|
||||
}
|
||||
|
||||
c.Settings["listen"] = map[string]any{"port": 4242}
|
||||
c.Settings["static_host_map"] = map[string]any{
|
||||
lh1: []any{"1.1.1.1:4242"},
|
||||
"10.128.0.42": []any{"1.2.3.4:4242"},
|
||||
}
|
||||
|
||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
|
||||
nt := new(bart.Lite)
|
||||
nt.Insert(myVpnNet)
|
||||
cs := &CertState{
|
||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||
myVpnNetworksTable: nt,
|
||||
}
|
||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||
require.NoError(t, err)
|
||||
lh.ifce = &mockEncWriter{}
|
||||
|
||||
//test that we actually have the static entry:
|
||||
out := lh.Query(testStaticHost)
|
||||
assert.NotNil(t, out)
|
||||
assert.Equal(t, out.vpnAddrs[0], testStaticHost)
|
||||
out.Rebuild([]netip.Prefix{}) //why tho
|
||||
assert.Equal(t, out.addrs[0], myUdpAddr2)
|
||||
|
||||
//bolt on a lower numbered primary IP
|
||||
am := lh.unlockedGetRemoteList([]netip.Addr{testStaticHost})
|
||||
am.vpnAddrs = []netip.Addr{testSameHostNotStatic, testStaticHost}
|
||||
lh.addrMap[testSameHostNotStatic] = am
|
||||
out.Rebuild([]netip.Prefix{}) //???
|
||||
|
||||
//test that we actually have the static entry:
|
||||
out = lh.Query(testStaticHost)
|
||||
assert.NotNil(t, out)
|
||||
assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
|
||||
assert.Equal(t, out.vpnAddrs[1], testStaticHost)
|
||||
assert.Equal(t, out.addrs[0], myUdpAddr2)
|
||||
|
||||
//test that we actually have the static entry for BOTH:
|
||||
out2 := lh.Query(testSameHostNotStatic)
|
||||
assert.Same(t, out2, out)
|
||||
|
||||
//now do the delete
|
||||
lh.DeleteVpnAddrs([]netip.Addr{testSameHostNotStatic, testStaticHost})
|
||||
//verify
|
||||
out = lh.Query(testSameHostNotStatic)
|
||||
assert.NotNil(t, out)
|
||||
if out == nil {
|
||||
t.Fatal("expected non-nil query for the static host")
|
||||
}
|
||||
assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
|
||||
assert.Equal(t, out.vpnAddrs[1], testStaticHost)
|
||||
assert.Equal(t, out.addrs[0], myUdpAddr2)
|
||||
}
|
||||
|
||||
func TestLighthouse_DeletesWork(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
|
||||
myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
|
||||
testHost := netip.MustParseAddr("10.128.0.42")
|
||||
|
||||
c := config.NewC(l)
|
||||
lh1 := "10.128.0.2"
|
||||
c.Settings["lighthouse"] = map[string]any{
|
||||
"hosts": []any{lh1},
|
||||
"interval": "1s",
|
||||
}
|
||||
|
||||
c.Settings["listen"] = map[string]any{"port": 4242}
|
||||
c.Settings["static_host_map"] = map[string]any{
|
||||
lh1: []any{"1.1.1.1:4242"},
|
||||
}
|
||||
|
||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
|
||||
nt := new(bart.Lite)
|
||||
nt.Insert(myVpnNet)
|
||||
cs := &CertState{
|
||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||
myVpnNetworksTable: nt,
|
||||
}
|
||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||
require.NoError(t, err)
|
||||
lh.ifce = &mockEncWriter{}
|
||||
|
||||
//insert the host
|
||||
am := lh.unlockedGetRemoteList([]netip.Addr{testHost})
|
||||
am.vpnAddrs = []netip.Addr{testHost}
|
||||
am.addrs = []netip.AddrPort{myUdpAddr2}
|
||||
lh.addrMap[testHost] = am
|
||||
am.Rebuild([]netip.Prefix{}) //???
|
||||
|
||||
//test that we actually have the entry:
|
||||
out := lh.Query(testHost)
|
||||
assert.NotNil(t, out)
|
||||
assert.Equal(t, out.vpnAddrs[0], testHost)
|
||||
out.Rebuild([]netip.Prefix{}) //why tho
|
||||
assert.Equal(t, out.addrs[0], myUdpAddr2)
|
||||
|
||||
//now do the delete
|
||||
lh.DeleteVpnAddrs([]netip.Addr{testHost})
|
||||
//verify
|
||||
out = lh.Query(testHost)
|
||||
assert.Nil(t, out)
|
||||
}
|
||||
|
||||
37
main.go
37
main.go
@@ -5,8 +5,6 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
@@ -15,7 +13,7 @@ import (
|
||||
"github.com/slackhq/nebula/sshd"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"go.yaml.in/yaml/v3"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type m = map[string]any
|
||||
@@ -29,10 +27,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
}
|
||||
}()
|
||||
|
||||
if buildVersion == "" {
|
||||
buildVersion = moduleVersion()
|
||||
}
|
||||
|
||||
l := logger
|
||||
l.Formatter = &logrus.TextFormatter{
|
||||
FullTimestamp: true,
|
||||
@@ -81,8 +75,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
if c.GetBool("sshd.enabled", false) {
|
||||
sshStart, err = configSSH(l, ssh, c)
|
||||
if err != nil {
|
||||
l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available")
|
||||
sshStart = nil
|
||||
return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -192,7 +185,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
|
||||
hostMap := NewHostMapFromConfig(l, c)
|
||||
punchy := NewPunchyFromConfig(l, c)
|
||||
connManager := newConnectionManagerFromConfig(l, c, hostMap, punchy)
|
||||
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
|
||||
if err != nil {
|
||||
return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
|
||||
@@ -228,6 +220,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
}
|
||||
}
|
||||
|
||||
checkInterval := c.GetInt("timers.connection_alive_interval", 5)
|
||||
pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10)
|
||||
|
||||
ifConfig := &InterfaceConfig{
|
||||
HostMap: hostMap,
|
||||
Inside: tun,
|
||||
@@ -236,8 +231,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
Firewall: fw,
|
||||
ServeDns: serveDns,
|
||||
HandshakeManager: handshakeManager,
|
||||
connectionManager: connManager,
|
||||
lightHouse: lightHouse,
|
||||
checkInterval: time.Second * time.Duration(checkInterval),
|
||||
pendingDeletionInterval: time.Second * time.Duration(pendingDeletionInterval),
|
||||
tryPromoteEvery: c.GetUint32("counters.try_promote", defaultPromoteEvery),
|
||||
reQueryEvery: c.GetUint32("counters.requery_every_packets", defaultReQueryEvery),
|
||||
reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
|
||||
@@ -248,6 +244,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
version: buildVersion,
|
||||
relayManager: NewRelayManager(ctx, l, hostMap, c),
|
||||
punchy: punchy,
|
||||
|
||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||
l: l,
|
||||
}
|
||||
@@ -287,7 +284,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
var dnsStart func()
|
||||
if lightHouse.amLighthouse && serveDns {
|
||||
l.Debugln("Starting dns server")
|
||||
dnsStart = dnsMain(l, pki.getCertState(), hostMap, c)
|
||||
dnsStart = dnsMain(ctx, l, pki.getCertState(), hostMap, c)
|
||||
}
|
||||
|
||||
return &Control{
|
||||
@@ -299,21 +296,5 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
statsStart,
|
||||
dnsStart,
|
||||
lightHouse.StartUpdateWorker,
|
||||
connManager.Start,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func moduleVersion() string {
|
||||
info, ok := debug.ReadBuildInfo()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
for _, dep := range info.Deps {
|
||||
if dep.Path == "github.com/slackhq/nebula" {
|
||||
return strings.TrimPrefix(dep.Version, "v")
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
31
outside.go
31
outside.go
@@ -81,7 +81,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
||||
// Pull the Roaming parts up here, and return in all call paths.
|
||||
f.handleHostRoaming(hostinfo, ip)
|
||||
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
|
||||
f.connectionManager.In(hostinfo)
|
||||
f.connectionManager.In(hostinfo.localIndexId)
|
||||
f.connectionManager.RelayUsed(h.RemoteIndex)
|
||||
|
||||
relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
|
||||
@@ -213,7 +213,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
||||
|
||||
f.handleHostRoaming(hostinfo, ip)
|
||||
|
||||
f.connectionManager.In(hostinfo)
|
||||
f.connectionManager.In(hostinfo.localIndexId)
|
||||
}
|
||||
|
||||
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
|
||||
@@ -254,18 +254,16 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort
|
||||
|
||||
}
|
||||
|
||||
// handleEncrypted returns true if a packet should be processed, false otherwise
|
||||
func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
|
||||
// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
|
||||
if ci == nil {
|
||||
// If connectionstate exists and the replay protector allows, process packet
|
||||
// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
|
||||
if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
|
||||
if addr.IsValid() {
|
||||
f.maybeSendRecvError(addr, h.RemoteIndex)
|
||||
}
|
||||
return false
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
// If the window check fails, refuse to process the packet, but don't send a recv error
|
||||
if !ci.window.Check(f.l, h.MessageCounter) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
@@ -314,11 +312,12 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
|
||||
offset := ipv6.HeaderLen // Start at the end of the ipv6 header
|
||||
next := 0
|
||||
for {
|
||||
if protoAt >= dataLen {
|
||||
if dataLen < offset {
|
||||
break
|
||||
}
|
||||
proto := layers.IPProtocol(data[protoAt])
|
||||
|
||||
proto := layers.IPProtocol(data[protoAt])
|
||||
//fmt.Println(proto, protoAt)
|
||||
switch proto {
|
||||
case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
|
||||
fp.Protocol = uint8(proto)
|
||||
@@ -366,7 +365,7 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
|
||||
|
||||
case layers.IPProtocolAH:
|
||||
// Auth headers, used by IPSec, have a different meaning for header length
|
||||
if dataLen <= offset+1 {
|
||||
if dataLen < offset+1 {
|
||||
break
|
||||
}
|
||||
|
||||
@@ -374,7 +373,7 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
|
||||
|
||||
default:
|
||||
// Normal ipv6 header length processing
|
||||
if dataLen <= offset+1 {
|
||||
if dataLen < offset+1 {
|
||||
break
|
||||
}
|
||||
|
||||
@@ -500,7 +499,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
||||
return false
|
||||
}
|
||||
|
||||
f.connectionManager.In(hostinfo)
|
||||
f.connectionManager.In(hostinfo.localIndexId)
|
||||
_, err = f.readers[q].Write(out)
|
||||
if err != nil {
|
||||
f.l.WithError(err).Error("Failed to write to tun")
|
||||
@@ -539,6 +538,10 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
|
||||
return
|
||||
}
|
||||
|
||||
if !hostinfo.RecvErrorExceeded() {
|
||||
return
|
||||
}
|
||||
|
||||
if hostinfo.remote.IsValid() && hostinfo.remote != addr {
|
||||
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
|
||||
return
|
||||
|
||||
@@ -117,45 +117,6 @@ func Test_newPacket_v6(t *testing.T) {
|
||||
err = newPacket(buffer.Bytes(), true, p)
|
||||
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||
|
||||
// A v6 packet with a hop-by-hop extension
|
||||
// ICMPv6 Payload (Echo Request)
|
||||
icmpLayer := layers.ICMPv6{
|
||||
TypeCode: layers.ICMPv6TypeEchoRequest,
|
||||
}
|
||||
// Hop-by-Hop Extension Header
|
||||
hopOption := layers.IPv6HopByHopOption{}
|
||||
hopOption.OptionData = []byte{0, 0, 0, 0}
|
||||
hopByHop := layers.IPv6HopByHop{}
|
||||
hopByHop.Options = append(hopByHop.Options, &hopOption)
|
||||
|
||||
ip = layers.IPv6{
|
||||
Version: 6,
|
||||
HopLimit: 128,
|
||||
NextHeader: layers.IPProtocolIPv6Destination,
|
||||
SrcIP: net.IPv6linklocalallrouters,
|
||||
DstIP: net.IPv6linklocalallnodes,
|
||||
}
|
||||
|
||||
buffer.Clear()
|
||||
err = gopacket.SerializeLayers(buffer, gopacket.SerializeOptions{
|
||||
ComputeChecksums: false,
|
||||
FixLengths: true,
|
||||
}, &ip, &hopByHop, &icmpLayer)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
// Ensure buffer length checks during parsing with the next 2 tests.
|
||||
|
||||
// A full IPv6 header and 1 byte in the first extension, but missing
|
||||
// the length byte.
|
||||
err = newPacket(buffer.Bytes()[:41], true, p)
|
||||
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||
|
||||
// A full IPv6 header plus 1 full extension, but only 1 byte of the
|
||||
// next layer, missing length byte
|
||||
err = newPacket(buffer.Bytes()[:49], true, p)
|
||||
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||
|
||||
// A good ICMP packet
|
||||
ip = layers.IPv6{
|
||||
Version: 6,
|
||||
@@ -327,10 +288,6 @@ func Test_newPacket_v6(t *testing.T) {
|
||||
assert.Equal(t, uint16(22), p.LocalPort)
|
||||
assert.False(t, p.Fragment)
|
||||
|
||||
// Ensure buffer bounds checking during processing
|
||||
err = newPacket(b[:41], true, p)
|
||||
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||
|
||||
// Invalid AH header
|
||||
b = buffer.Bytes()
|
||||
err = newPacket(b, true, p)
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
@@ -72,51 +70,3 @@ func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {
|
||||
|
||||
return removed
|
||||
}
|
||||
|
||||
func prefixToMask(prefix netip.Prefix) netip.Addr {
|
||||
pLen := 128
|
||||
if prefix.Addr().Is4() {
|
||||
pLen = 32
|
||||
}
|
||||
|
||||
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
|
||||
return addr
|
||||
}
|
||||
|
||||
func flipBytes(b []byte) []byte {
|
||||
for i := 0; i < len(b); i++ {
|
||||
b[i] ^= 0xFF
|
||||
}
|
||||
return b
|
||||
}
|
||||
func orBytes(a []byte, b []byte) []byte {
|
||||
ret := make([]byte, len(a))
|
||||
for i := 0; i < len(a); i++ {
|
||||
ret[i] = a[i] | b[i]
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func getBroadcast(cidr netip.Prefix) netip.Addr {
|
||||
broadcast, _ := netip.AddrFromSlice(
|
||||
orBytes(
|
||||
cidr.Addr().AsSlice(),
|
||||
flipBytes(prefixToMask(cidr).AsSlice()),
|
||||
),
|
||||
)
|
||||
return broadcast
|
||||
}
|
||||
|
||||
func selectGateway(dest netip.Prefix, gateways []netip.Prefix) (netip.Prefix, error) {
|
||||
for _, gateway := range gateways {
|
||||
if dest.Addr().Is4() && gateway.Addr().Is4() {
|
||||
return gateway, nil
|
||||
}
|
||||
|
||||
if dest.Addr().Is6() && gateway.Addr().Is6() {
|
||||
return gateway, nil
|
||||
}
|
||||
}
|
||||
|
||||
return netip.Prefix{}, fmt.Errorf("no gateway found for %v in the list of vpn networks", dest)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
@@ -294,6 +295,7 @@ func (t *tun) activate6(network netip.Prefix) error {
|
||||
Vltime: 0xffffffff,
|
||||
Pltime: 0xffffffff,
|
||||
},
|
||||
//TODO: CERT-V2 should we disable DAD (duplicate address detection) and mark this as a secured address?
|
||||
Flags: _IN6_IFF_NODAD,
|
||||
}
|
||||
|
||||
@@ -552,3 +554,13 @@ func (t *tun) Name() string {
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
||||
}
|
||||
|
||||
func prefixToMask(prefix netip.Prefix) netip.Addr {
|
||||
pLen := 128
|
||||
if prefix.Addr().Is4() {
|
||||
pLen = 32
|
||||
}
|
||||
|
||||
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
|
||||
return addr
|
||||
}
|
||||
|
||||
@@ -10,9 +10,11 @@ import (
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
@@ -20,18 +22,12 @@ import (
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
netroute "golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
// FIODGNAME is defined in sys/sys/filio.h on FreeBSD
|
||||
// For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
|
||||
FIODGNAME = 0x80106678
|
||||
TUNSIFMODE = 0x8004745e
|
||||
TUNSIFHEAD = 0x80047460
|
||||
OSIOCAIFADDR_IN6 = 0x8088691b
|
||||
IN6_IFF_NODAD = 0x0020
|
||||
)
|
||||
|
||||
type fiodgnameArg struct {
|
||||
@@ -41,159 +37,43 @@ type fiodgnameArg struct {
|
||||
}
|
||||
|
||||
type ifreqRename struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Name [16]byte
|
||||
Data uintptr
|
||||
}
|
||||
|
||||
type ifreqDestroy struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Name [16]byte
|
||||
pad [16]byte
|
||||
}
|
||||
|
||||
type ifReq struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Flags uint16
|
||||
}
|
||||
|
||||
type ifreqMTU struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
MTU int32
|
||||
}
|
||||
|
||||
type addrLifetime struct {
|
||||
Expire uint64
|
||||
Preferred uint64
|
||||
Vltime uint32
|
||||
Pltime uint32
|
||||
}
|
||||
|
||||
type ifreqAlias4 struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Addr unix.RawSockaddrInet4
|
||||
DstAddr unix.RawSockaddrInet4
|
||||
MaskAddr unix.RawSockaddrInet4
|
||||
VHid uint32
|
||||
}
|
||||
|
||||
type ifreqAlias6 struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Addr unix.RawSockaddrInet6
|
||||
DstAddr unix.RawSockaddrInet6
|
||||
PrefixMask unix.RawSockaddrInet6
|
||||
Flags uint32
|
||||
Lifetime addrLifetime
|
||||
VHid uint32
|
||||
}
|
||||
|
||||
type tun struct {
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
linkAddr *netroute.LinkAddr
|
||||
l *logrus.Logger
|
||||
devFd int
|
||||
}
|
||||
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
// use readv() to read from the tunnel device, to eliminate the need for copying the buffer
|
||||
if t.devFd < 0 {
|
||||
return -1, syscall.EINVAL
|
||||
}
|
||||
|
||||
// first 4 bytes is protocol family, in network byte order
|
||||
head := make([]byte, 4)
|
||||
|
||||
iovecs := []syscall.Iovec{
|
||||
{&head[0], 4},
|
||||
{&to[0], uint64(len(to))},
|
||||
}
|
||||
|
||||
n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
||||
|
||||
var err error
|
||||
if errno != 0 {
|
||||
err = syscall.Errno(errno)
|
||||
} else {
|
||||
err = nil
|
||||
}
|
||||
// fix bytes read number to exclude header
|
||||
bytesRead := int(n)
|
||||
if bytesRead < 0 {
|
||||
return bytesRead, err
|
||||
} else if bytesRead < 4 {
|
||||
return 0, err
|
||||
} else {
|
||||
return bytesRead - 4, err
|
||||
}
|
||||
}
|
||||
|
||||
// Write is only valid for single threaded use
|
||||
func (t *tun) Write(from []byte) (int, error) {
|
||||
// use writev() to write to the tunnel device, to eliminate the need for copying the buffer
|
||||
if t.devFd < 0 {
|
||||
return -1, syscall.EINVAL
|
||||
}
|
||||
|
||||
if len(from) <= 1 {
|
||||
return 0, syscall.EIO
|
||||
}
|
||||
ipVer := from[0] >> 4
|
||||
var head []byte
|
||||
// first 4 bytes is protocol family, in network byte order
|
||||
if ipVer == 4 {
|
||||
head = []byte{0, 0, 0, syscall.AF_INET}
|
||||
} else if ipVer == 6 {
|
||||
head = []byte{0, 0, 0, syscall.AF_INET6}
|
||||
} else {
|
||||
return 0, fmt.Errorf("unable to determine IP version from packet")
|
||||
}
|
||||
iovecs := []syscall.Iovec{
|
||||
{&head[0], 4},
|
||||
{&from[0], uint64(len(from))},
|
||||
}
|
||||
|
||||
n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
||||
|
||||
var err error
|
||||
if errno != 0 {
|
||||
err = syscall.Errno(errno)
|
||||
} else {
|
||||
err = nil
|
||||
}
|
||||
|
||||
return int(n) - 4, err
|
||||
io.ReadWriteCloser
|
||||
}
|
||||
|
||||
func (t *tun) Close() error {
|
||||
if t.devFd >= 0 {
|
||||
err := syscall.Close(t.devFd)
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("Error closing device")
|
||||
if t.ReadWriteCloser != nil {
|
||||
if err := t.ReadWriteCloser.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
t.devFd = -1
|
||||
|
||||
c := make(chan struct{})
|
||||
go func() {
|
||||
// destroying the interface can block if a read() is still pending. Do this asynchronously.
|
||||
defer close(c)
|
||||
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||
if err == nil {
|
||||
defer syscall.Close(s)
|
||||
ifreq := ifreqDestroy{Name: t.deviceBytes()}
|
||||
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
||||
}
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("Error destroying tunnel")
|
||||
return err
|
||||
}
|
||||
}()
|
||||
defer syscall.Close(s)
|
||||
|
||||
// wait up to 1 second so we start blocking at the ioctl
|
||||
select {
|
||||
case <-c:
|
||||
case <-time.After(1 * time.Second):
|
||||
}
|
||||
ifreq := ifreqDestroy{Name: t.deviceBytes()}
|
||||
|
||||
// Destroy the interface
|
||||
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -205,37 +85,32 @@ func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun,
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open existing tun device
|
||||
var fd int
|
||||
var file *os.File
|
||||
var err error
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
if deviceName != "" {
|
||||
fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0)
|
||||
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
|
||||
}
|
||||
if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
|
||||
// If the device doesn't already exist, request a new one and rename it
|
||||
fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0)
|
||||
file, err = os.OpenFile("/dev/tun", os.O_RDWR, 0)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read the name of the interface
|
||||
rawConn, err := file.SyscallConn()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SyscallConn: %v", err)
|
||||
}
|
||||
|
||||
var name [16]byte
|
||||
var ctrlErr error
|
||||
rawConn.Control(func(fd uintptr) {
|
||||
// Read the name of the interface
|
||||
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
|
||||
ctrlErr := ioctl(uintptr(fd), FIODGNAME, uintptr(unsafe.Pointer(&arg)))
|
||||
|
||||
if ctrlErr == nil {
|
||||
// set broadcast mode and multicast
|
||||
ifmode := uint32(unix.IFF_BROADCAST | unix.IFF_MULTICAST)
|
||||
ctrlErr = ioctl(uintptr(fd), TUNSIFMODE, uintptr(unsafe.Pointer(&ifmode)))
|
||||
}
|
||||
|
||||
if ctrlErr == nil {
|
||||
// turn on link-layer mode, to support ipv6
|
||||
ifhead := uint32(1)
|
||||
ctrlErr = ioctl(uintptr(fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&ifhead)))
|
||||
}
|
||||
|
||||
ctrlErr = ioctl(fd, FIODGNAME, uintptr(unsafe.Pointer(&arg)))
|
||||
})
|
||||
if ctrlErr != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -247,7 +122,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
|
||||
// If the name doesn't match the desired interface name, rename it now
|
||||
if ifName != deviceName {
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
s, err := syscall.Socket(
|
||||
syscall.AF_INET,
|
||||
syscall.SOCK_DGRAM,
|
||||
syscall.IPPROTO_IP,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -270,11 +149,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
}
|
||||
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
devFd: fd,
|
||||
}
|
||||
|
||||
err = t.reload(c, true)
|
||||
@@ -293,111 +172,38 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
}
|
||||
|
||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||
if cidr.Addr().Is4() {
|
||||
ifr := ifreqAlias4{
|
||||
Name: t.deviceBytes(),
|
||||
Addr: unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: cidr.Addr().As4(),
|
||||
},
|
||||
DstAddr: unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: getBroadcast(cidr).As4(),
|
||||
},
|
||||
MaskAddr: unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: prefixToMask(cidr).As4(),
|
||||
},
|
||||
VHid: 0,
|
||||
}
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
// Note: unix.SIOCAIFADDR corresponds to FreeBSD's OSIOCAIFADDR
|
||||
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
||||
}
|
||||
return nil
|
||||
var err error
|
||||
// TODO use syscalls instead of exec.Command
|
||||
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
|
||||
if cidr.Addr().Is6() {
|
||||
ifr := ifreqAlias6{
|
||||
Name: t.deviceBytes(),
|
||||
Addr: unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: cidr.Addr().As16(),
|
||||
},
|
||||
PrefixMask: unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: prefixToMask(cidr).As16(),
|
||||
},
|
||||
Lifetime: addrLifetime{
|
||||
Expire: 0,
|
||||
Preferred: 0,
|
||||
Vltime: 0xffffffff,
|
||||
Pltime: 0xffffffff,
|
||||
},
|
||||
Flags: IN6_IFF_NODAD,
|
||||
}
|
||||
s, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
||||
}
|
||||
return nil
|
||||
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device)
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||
}
|
||||
|
||||
return fmt.Errorf("unknown address type %v", cidr)
|
||||
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
|
||||
// Unsafe path routes
|
||||
return t.addRoutes(false)
|
||||
}
|
||||
|
||||
func (t *tun) Activate() error {
|
||||
// Setup our default MTU
|
||||
err := t.setMTU()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
linkAddr, err := getLinkAddr(t.Device)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if linkAddr == nil {
|
||||
return fmt.Errorf("unable to discover link_addr for tun interface")
|
||||
}
|
||||
t.linkAddr = linkAddr
|
||||
|
||||
for i := range t.vpnNetworks {
|
||||
err := t.addIp(t.vpnNetworks[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return t.addRoutes(false)
|
||||
}
|
||||
|
||||
func (t *tun) setMTU() error {
|
||||
// Set the MTU on the device
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MTU)}
|
||||
err = ioctl(uintptr(s), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm)))
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
@@ -462,16 +268,15 @@ func (t *tun) addRoutes(logErrors bool) error {
|
||||
continue
|
||||
}
|
||||
|
||||
err := addRoute(r.Cidr, t.linkAddr)
|
||||
if err != nil {
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
} else {
|
||||
return retErr
|
||||
}
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Added route")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -484,8 +289,9 @@ func (t *tun) removeRoutes(routes []Route) error {
|
||||
continue
|
||||
}
|
||||
|
||||
err := delRoute(r.Cidr, t.linkAddr)
|
||||
if err != nil {
|
||||
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device)
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
@@ -500,120 +306,3 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
defer unix.Close(sock)
|
||||
|
||||
route := &netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
Type: unix.RTM_ADD,
|
||||
Flags: unix.RTF_UP,
|
||||
Seq: 1,
|
||||
}
|
||||
|
||||
if prefix.Addr().Is4() {
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: gateway,
|
||||
}
|
||||
} else {
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: gateway,
|
||||
}
|
||||
}
|
||||
|
||||
data, err := route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
|
||||
_, err = unix.Write(sock, data[:])
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EEXIST) {
|
||||
// Try to do a change
|
||||
route.Type = unix.RTM_CHANGE
|
||||
data, err = route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
|
||||
}
|
||||
_, err = unix.Write(sock, data[:])
|
||||
fmt.Println("DOING CHANGE")
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
defer unix.Close(sock)
|
||||
|
||||
route := netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
Type: unix.RTM_DELETE,
|
||||
Seq: 1,
|
||||
}
|
||||
|
||||
if prefix.Addr().Is4() {
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: gateway,
|
||||
}
|
||||
} else {
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: gateway,
|
||||
}
|
||||
}
|
||||
|
||||
data, err := route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
_, err = unix.Write(sock, data[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getLinkAddr Gets the link address for the interface of the given name
|
||||
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
|
||||
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, m := range msgs {
|
||||
switch m := m.(type) {
|
||||
case *netroute.InterfaceMessage:
|
||||
if m.Name == name {
|
||||
sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr)
|
||||
if ok {
|
||||
return sa, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -38,7 +38,6 @@ type tun struct {
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
routeChan chan struct{}
|
||||
useSystemRoutes bool
|
||||
useSystemRoutesBufferSize int
|
||||
|
||||
l *logrus.Logger
|
||||
}
|
||||
@@ -130,7 +129,6 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n
|
||||
vpnNetworks: vpnNetworks,
|
||||
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
||||
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
||||
l: l,
|
||||
}
|
||||
|
||||
@@ -293,6 +291,7 @@ func (t *tun) addIPs(link netlink.Link) error {
|
||||
|
||||
//add all new addresses
|
||||
for i := range newAddrs {
|
||||
//TODO: CERT-V2 do we want to stack errors and try as many ops as possible?
|
||||
//AddrReplace still adds new IPs, but if their properties change it will change them as well
|
||||
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
|
||||
return err
|
||||
@@ -360,11 +359,6 @@ func (t *tun) Activate() error {
|
||||
t.l.WithError(err).Error("Failed to set tun tx queue length")
|
||||
}
|
||||
|
||||
const modeNone = 1
|
||||
if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
|
||||
t.l.WithError(err).Warn("Failed to disable link local address generation")
|
||||
}
|
||||
|
||||
if err = t.addIPs(link); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -537,13 +531,7 @@ func (t *tun) watchRoutes() {
|
||||
rch := make(chan netlink.RouteUpdate)
|
||||
doneChan := make(chan struct{})
|
||||
|
||||
netlinkOptions := netlink.RouteSubscribeOptions{
|
||||
ReceiveBufferSize: t.useSystemRoutesBufferSize,
|
||||
ReceiveBufferForceSize: t.useSystemRoutesBufferSize != 0,
|
||||
ErrorCallback: func(e error) { t.l.WithError(e).Errorf("netlink error") },
|
||||
}
|
||||
|
||||
if err := netlink.RouteSubscribeWithOptions(rch, doneChan, netlinkOptions); err != nil {
|
||||
if err := netlink.RouteSubscribe(rch, doneChan); err != nil {
|
||||
t.l.WithError(err).Errorf("failed to subscribe to system route changes")
|
||||
return
|
||||
}
|
||||
@@ -553,14 +541,8 @@ func (t *tun) watchRoutes() {
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case r, ok := <-rch:
|
||||
if ok {
|
||||
case r := <-rch:
|
||||
t.updateRoutes(r)
|
||||
} else {
|
||||
// may be should do something here as
|
||||
// netlink stops sending updates
|
||||
return
|
||||
}
|
||||
case <-doneChan:
|
||||
// netlink.RouteSubscriber will close the rch for us
|
||||
return
|
||||
@@ -642,11 +624,6 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
||||
return
|
||||
}
|
||||
|
||||
if r.Dst == nil {
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, no destination address")
|
||||
return
|
||||
}
|
||||
|
||||
dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
|
||||
if !ok {
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address")
|
||||
|
||||
@@ -4,12 +4,13 @@
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
@@ -19,42 +20,11 @@ import (
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
netroute "golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
SIOCAIFADDR_IN6 = 0x8080696b
|
||||
TUNSIFHEAD = 0x80047442
|
||||
TUNSIFMODE = 0x80047458
|
||||
)
|
||||
|
||||
type ifreqAlias4 struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Addr unix.RawSockaddrInet4
|
||||
DstAddr unix.RawSockaddrInet4
|
||||
MaskAddr unix.RawSockaddrInet4
|
||||
}
|
||||
|
||||
type ifreqAlias6 struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Addr unix.RawSockaddrInet6
|
||||
DstAddr unix.RawSockaddrInet6
|
||||
PrefixMask unix.RawSockaddrInet6
|
||||
Flags uint32
|
||||
Lifetime addrLifetime
|
||||
}
|
||||
|
||||
type ifreq struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
data int
|
||||
}
|
||||
|
||||
type addrLifetime struct {
|
||||
Expire uint64
|
||||
Preferred uint64
|
||||
Vltime uint32
|
||||
Pltime uint32
|
||||
type ifreqDestroy struct {
|
||||
Name [16]byte
|
||||
pad [16]byte
|
||||
}
|
||||
|
||||
type tun struct {
|
||||
@@ -64,18 +34,40 @@ type tun struct {
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
f *os.File
|
||||
fd int
|
||||
|
||||
io.ReadWriteCloser
|
||||
}
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
func (t *tun) Close() error {
|
||||
if t.ReadWriteCloser != nil {
|
||||
if err := t.ReadWriteCloser.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
ifreq := ifreqDestroy{Name: t.deviceBytes()}
|
||||
|
||||
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
||||
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
|
||||
}
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open tun device
|
||||
var file *os.File
|
||||
var err error
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
if deviceName == "" {
|
||||
@@ -85,19 +77,13 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
||||
}
|
||||
|
||||
fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
|
||||
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = unix.SetNonblock(fd, true)
|
||||
if err != nil {
|
||||
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
|
||||
}
|
||||
|
||||
t := &tun{
|
||||
f: os.NewFile(uintptr(fd), ""),
|
||||
fd: fd,
|
||||
ReadWriteCloser: file,
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
@@ -119,225 +105,40 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *tun) Close() error {
|
||||
if t.f != nil {
|
||||
if err := t.f.Close(); err != nil {
|
||||
return fmt.Errorf("error closing tun file: %w", err)
|
||||
}
|
||||
|
||||
// t.f.Close should have handled it for us but let's be extra sure
|
||||
_ = unix.Close(t.fd)
|
||||
|
||||
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
ifr := ifreq{Name: t.deviceBytes()}
|
||||
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifr)))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
rc, err := t.f.SyscallConn()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err)
|
||||
}
|
||||
|
||||
var errno syscall.Errno
|
||||
var n uintptr
|
||||
err = rc.Read(func(fd uintptr) bool {
|
||||
// first 4 bytes is protocol family, in network byte order
|
||||
head := [4]byte{}
|
||||
iovecs := []syscall.Iovec{
|
||||
{&head[0], 4},
|
||||
{&to[0], uint64(len(to))},
|
||||
}
|
||||
|
||||
n, _, errno = syscall.Syscall(syscall.SYS_READV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
||||
if errno.Temporary() {
|
||||
// We got an EAGAIN, EINTR, or EWOULDBLOCK, go again
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
if err == syscall.EBADF || err.Error() == "use of closed file" {
|
||||
// Go doesn't export poll.ErrFileClosing but happily reports it to us so here we are
|
||||
// https://github.com/golang/go/blob/master/src/internal/poll/fd_poll_runtime.go#L121
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
return 0, fmt.Errorf("failed to make read call for tun: %w", err)
|
||||
}
|
||||
|
||||
if errno != 0 {
|
||||
return 0, fmt.Errorf("failed to make inner read call for tun: %w", errno)
|
||||
}
|
||||
|
||||
// fix bytes read number to exclude header
|
||||
bytesRead := int(n)
|
||||
if bytesRead < 0 {
|
||||
return bytesRead, nil
|
||||
} else if bytesRead < 4 {
|
||||
return 0, nil
|
||||
} else {
|
||||
return bytesRead - 4, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Write is only valid for single threaded use
|
||||
func (t *tun) Write(from []byte) (int, error) {
|
||||
if len(from) <= 1 {
|
||||
return 0, syscall.EIO
|
||||
}
|
||||
|
||||
ipVer := from[0] >> 4
|
||||
var head [4]byte
|
||||
// first 4 bytes is protocol family, in network byte order
|
||||
if ipVer == 4 {
|
||||
head[3] = syscall.AF_INET
|
||||
} else if ipVer == 6 {
|
||||
head[3] = syscall.AF_INET6
|
||||
} else {
|
||||
return 0, fmt.Errorf("unable to determine IP version from packet")
|
||||
}
|
||||
|
||||
rc, err := t.f.SyscallConn()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var errno syscall.Errno
|
||||
var n uintptr
|
||||
err = rc.Write(func(fd uintptr) bool {
|
||||
iovecs := []syscall.Iovec{
|
||||
{&head[0], 4},
|
||||
{&from[0], uint64(len(from))},
|
||||
}
|
||||
|
||||
n, _, errno = syscall.Syscall(syscall.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
||||
// According to NetBSD documentation for TUN, writes will only return errors in which
|
||||
// this packet will never be delivered so just go on living life.
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if errno != 0 {
|
||||
return 0, errno
|
||||
}
|
||||
|
||||
return int(n) - 4, err
|
||||
}
|
||||
|
||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||
if cidr.Addr().Is4() {
|
||||
var req ifreqAlias4
|
||||
req.Name = t.deviceBytes()
|
||||
req.Addr = unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: cidr.Addr().As4(),
|
||||
}
|
||||
req.DstAddr = unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: cidr.Addr().As4(),
|
||||
}
|
||||
req.MaskAddr = unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: prefixToMask(cidr).As4(),
|
||||
var err error
|
||||
|
||||
// TODO use syscalls instead of exec.Command
|
||||
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
|
||||
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if cidr.Addr().Is6() {
|
||||
var req ifreqAlias6
|
||||
req.Name = t.deviceBytes()
|
||||
req.Addr = unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: cidr.Addr().As16(),
|
||||
}
|
||||
req.PrefixMask = unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: prefixToMask(cidr).As16(),
|
||||
}
|
||||
req.Lifetime = addrLifetime{
|
||||
Vltime: 0xffffffff,
|
||||
Pltime: 0xffffffff,
|
||||
}
|
||||
|
||||
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("unknown address type %v", cidr)
|
||||
}
|
||||
|
||||
func (t *tun) Activate() error {
|
||||
mode := int32(unix.IFF_BROADCAST)
|
||||
err := ioctl(uintptr(t.fd), TUNSIFMODE, uintptr(unsafe.Pointer(&mode)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set tun device mode: %w", err)
|
||||
}
|
||||
|
||||
v := 1
|
||||
err = ioctl(uintptr(t.fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&v)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set tun device head: %w", err)
|
||||
}
|
||||
|
||||
err = t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set tun mtu: %w", err)
|
||||
}
|
||||
|
||||
for i := range t.vpnNetworks {
|
||||
err = t.addIp(t.vpnNetworks[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
|
||||
// Unsafe path routes
|
||||
return t.addRoutes(false)
|
||||
}
|
||||
|
||||
func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
func (t *tun) Activate() error {
|
||||
for i := range t.vpnNetworks {
|
||||
err := t.addIp(t.vpnNetworks[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
ir := ifreq{Name: t.deviceBytes(), data: int(value)}
|
||||
err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
@@ -396,23 +197,21 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
routes := *t.Routes.Load()
|
||||
|
||||
for _, r := range routes {
|
||||
if len(r.Via) == 0 || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
|
||||
err := addRoute(r.Cidr, t.vpnNetworks)
|
||||
if err != nil {
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
} else {
|
||||
return retErr
|
||||
}
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Added route")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -425,8 +224,10 @@ func (t *tun) removeRoutes(routes []Route) error {
|
||||
continue
|
||||
}
|
||||
|
||||
err := delRoute(r.Cidr, t.vpnNetworks)
|
||||
if err != nil {
|
||||
//TODO: CERT-V2 is this right?
|
||||
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
@@ -441,109 +242,3 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
defer unix.Close(sock)
|
||||
|
||||
route := &netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
Type: unix.RTM_ADD,
|
||||
Flags: unix.RTF_UP | unix.RTF_GATEWAY,
|
||||
Seq: 1,
|
||||
}
|
||||
|
||||
if prefix.Addr().Is4() {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
||||
}
|
||||
} else {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
||||
}
|
||||
}
|
||||
|
||||
data, err := route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
|
||||
_, err = unix.Write(sock, data[:])
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EEXIST) {
|
||||
// Try to do a change
|
||||
route.Type = unix.RTM_CHANGE
|
||||
data, err = route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
|
||||
}
|
||||
_, err = unix.Write(sock, data[:])
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
defer unix.Close(sock)
|
||||
|
||||
route := netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
Type: unix.RTM_DELETE,
|
||||
Seq: 1,
|
||||
}
|
||||
|
||||
if prefix.Addr().Is4() {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
||||
}
|
||||
} else {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
||||
}
|
||||
}
|
||||
|
||||
data, err := route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
_, err = unix.Write(sock, data[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,50 +4,23 @@
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
netroute "golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
SIOCAIFADDR_IN6 = 0x8080691a
|
||||
)
|
||||
|
||||
type ifreqAlias4 struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Addr unix.RawSockaddrInet4
|
||||
DstAddr unix.RawSockaddrInet4
|
||||
MaskAddr unix.RawSockaddrInet4
|
||||
}
|
||||
|
||||
type ifreqAlias6 struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
Addr unix.RawSockaddrInet6
|
||||
DstAddr unix.RawSockaddrInet6
|
||||
PrefixMask unix.RawSockaddrInet6
|
||||
Flags uint32
|
||||
Lifetime [2]uint32
|
||||
}
|
||||
|
||||
type ifreq struct {
|
||||
Name [unix.IFNAMSIZ]byte
|
||||
data int
|
||||
}
|
||||
|
||||
type tun struct {
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
@@ -55,42 +28,44 @@ type tun struct {
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
f *os.File
|
||||
fd int
|
||||
|
||||
io.ReadWriteCloser
|
||||
|
||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||
out []byte
|
||||
}
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
func (t *tun) Close() error {
|
||||
if t.ReadWriteCloser != nil {
|
||||
return t.ReadWriteCloser.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
|
||||
}
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open tun device
|
||||
var err error
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
if deviceName == "" {
|
||||
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
||||
}
|
||||
if !deviceNameRE.MatchString(deviceName) {
|
||||
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
||||
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
|
||||
}
|
||||
|
||||
fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
|
||||
if !deviceNameRE.MatchString(deviceName) {
|
||||
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
|
||||
}
|
||||
|
||||
file, err := os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = unix.SetNonblock(fd, true)
|
||||
if err != nil {
|
||||
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
|
||||
}
|
||||
|
||||
t := &tun{
|
||||
f: os.NewFile(uintptr(fd), ""),
|
||||
fd: fd,
|
||||
ReadWriteCloser: file,
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
@@ -112,154 +87,6 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *tun) Close() error {
|
||||
if t.f != nil {
|
||||
if err := t.f.Close(); err != nil {
|
||||
return fmt.Errorf("error closing tun file: %w", err)
|
||||
}
|
||||
|
||||
// t.f.Close should have handled it for us but let's be extra sure
|
||||
_ = unix.Close(t.fd)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
buf := make([]byte, len(to)+4)
|
||||
|
||||
n, err := t.f.Read(buf)
|
||||
|
||||
copy(to, buf[4:])
|
||||
return n - 4, err
|
||||
}
|
||||
|
||||
// Write is only valid for single threaded use
|
||||
func (t *tun) Write(from []byte) (int, error) {
|
||||
buf := t.out
|
||||
if cap(buf) < len(from)+4 {
|
||||
buf = make([]byte, len(from)+4)
|
||||
t.out = buf
|
||||
}
|
||||
buf = buf[:len(from)+4]
|
||||
|
||||
if len(from) == 0 {
|
||||
return 0, syscall.EIO
|
||||
}
|
||||
|
||||
// Determine the IP Family for the NULL L2 Header
|
||||
ipVer := from[0] >> 4
|
||||
if ipVer == 4 {
|
||||
buf[3] = syscall.AF_INET
|
||||
} else if ipVer == 6 {
|
||||
buf[3] = syscall.AF_INET6
|
||||
} else {
|
||||
return 0, fmt.Errorf("unable to determine IP version from packet")
|
||||
}
|
||||
|
||||
copy(buf[4:], from)
|
||||
|
||||
n, err := t.f.Write(buf)
|
||||
return n - 4, err
|
||||
}
|
||||
|
||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||
if cidr.Addr().Is4() {
|
||||
var req ifreqAlias4
|
||||
req.Name = t.deviceBytes()
|
||||
req.Addr = unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: cidr.Addr().As4(),
|
||||
}
|
||||
req.DstAddr = unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: cidr.Addr().As4(),
|
||||
}
|
||||
req.MaskAddr = unix.RawSockaddrInet4{
|
||||
Len: unix.SizeofSockaddrInet4,
|
||||
Family: unix.AF_INET,
|
||||
Addr: prefixToMask(cidr).As4(),
|
||||
}
|
||||
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
|
||||
}
|
||||
|
||||
err = addRoute(cidr, t.vpnNetworks)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set route for vpn network %v: %w", cidr, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if cidr.Addr().Is6() {
|
||||
var req ifreqAlias6
|
||||
req.Name = t.deviceBytes()
|
||||
req.Addr = unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: cidr.Addr().As16(),
|
||||
}
|
||||
req.PrefixMask = unix.RawSockaddrInet6{
|
||||
Len: unix.SizeofSockaddrInet6,
|
||||
Family: unix.AF_INET6,
|
||||
Addr: prefixToMask(cidr).As16(),
|
||||
}
|
||||
req.Lifetime[0] = 0xffffffff
|
||||
req.Lifetime[1] = 0xffffffff
|
||||
|
||||
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("unknown address type %v", cidr)
|
||||
}
|
||||
|
||||
func (t *tun) Activate() error {
|
||||
err := t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set tun mtu: %w", err)
|
||||
}
|
||||
|
||||
for i := range t.vpnNetworks {
|
||||
err = t.addIp(t.vpnNetworks[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return t.addRoutes(false)
|
||||
}
|
||||
|
||||
func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
|
||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(s)
|
||||
|
||||
ir := ifreq{Name: t.deviceBytes(), data: int(value)}
|
||||
err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||
if err != nil {
|
||||
@@ -297,42 +124,63 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||
var err error
|
||||
// TODO use syscalls instead of exec.Command
|
||||
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
|
||||
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
|
||||
cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||
}
|
||||
|
||||
// Unsafe path routes
|
||||
return t.addRoutes(false)
|
||||
}
|
||||
|
||||
func (t *tun) Activate() error {
|
||||
for i := range t.vpnNetworks {
|
||||
err := t.addIp(t.vpnNetworks[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
r, _ := t.routeTree.Load().Lookup(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *tun) Networks() []netip.Prefix {
|
||||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
func (t *tun) Name() string {
|
||||
return t.Device
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
|
||||
}
|
||||
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
routes := *t.Routes.Load()
|
||||
|
||||
for _, r := range routes {
|
||||
if len(r.Via) == 0 || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
|
||||
err := addRoute(r.Cidr, t.vpnNetworks)
|
||||
if err != nil {
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||
//TODO: CERT-V2 is this right?
|
||||
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
} else {
|
||||
return retErr
|
||||
}
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Added route")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -344,9 +192,10 @@ func (t *tun) removeRoutes(routes []Route) error {
|
||||
if !r.Install {
|
||||
continue
|
||||
}
|
||||
|
||||
err := delRoute(r.Cidr, t.vpnNetworks)
|
||||
if err != nil {
|
||||
//TODO: CERT-V2 is this right?
|
||||
cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
@@ -355,115 +204,52 @@ func (t *tun) removeRoutes(routes []Route) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) deviceBytes() (o [16]byte) {
|
||||
for i, c := range t.Device {
|
||||
o[i] = byte(c)
|
||||
}
|
||||
return
|
||||
func (t *tun) Networks() []netip.Prefix {
|
||||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
defer unix.Close(sock)
|
||||
|
||||
route := &netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
Type: unix.RTM_ADD,
|
||||
Flags: unix.RTF_UP | unix.RTF_GATEWAY,
|
||||
Seq: 1,
|
||||
func (t *tun) Name() string {
|
||||
return t.Device
|
||||
}
|
||||
|
||||
if prefix.Addr().Is4() {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
||||
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
buf := make([]byte, len(to)+4)
|
||||
|
||||
n, err := t.ReadWriteCloser.Read(buf)
|
||||
|
||||
copy(to, buf[4:])
|
||||
return n - 4, err
|
||||
}
|
||||
|
||||
// Write is only valid for single threaded use
|
||||
func (t *tun) Write(from []byte) (int, error) {
|
||||
buf := t.out
|
||||
if cap(buf) < len(from)+4 {
|
||||
buf = make([]byte, len(from)+4)
|
||||
t.out = buf
|
||||
}
|
||||
buf = buf[:len(from)+4]
|
||||
|
||||
if len(from) == 0 {
|
||||
return 0, syscall.EIO
|
||||
}
|
||||
|
||||
// Determine the IP Family for the NULL L2 Header
|
||||
ipVer := from[0] >> 4
|
||||
if ipVer == 4 {
|
||||
buf[3] = syscall.AF_INET
|
||||
} else if ipVer == 6 {
|
||||
buf[3] = syscall.AF_INET6
|
||||
} else {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
||||
}
|
||||
return 0, fmt.Errorf("unable to determine IP version from packet")
|
||||
}
|
||||
|
||||
data, err := route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
copy(buf[4:], from)
|
||||
|
||||
_, err = unix.Write(sock, data[:])
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EEXIST) {
|
||||
// Try to do a change
|
||||
route.Type = unix.RTM_CHANGE
|
||||
data, err = route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
|
||||
}
|
||||
_, err = unix.Write(sock, data[:])
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
defer unix.Close(sock)
|
||||
|
||||
route := netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
Type: unix.RTM_DELETE,
|
||||
Seq: 1,
|
||||
}
|
||||
|
||||
if prefix.Addr().Is4() {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
||||
}
|
||||
} else {
|
||||
gw, err := selectGateway(prefix, gateways)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Addrs = []netroute.Addr{
|
||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
||||
}
|
||||
}
|
||||
|
||||
data, err := route.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
_, err = unix.Write(sock, data[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
n, err := t.ReadWriteCloser.Write(buf)
|
||||
return n - 4, err
|
||||
}
|
||||
|
||||
@@ -180,7 +180,6 @@ func (c *PKClient) DeriveNoise(peerPubKey []byte) ([]byte, error) {
|
||||
pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true),
|
||||
pkcs11.NewAttribute(pkcs11.CKA_WRAP, true),
|
||||
pkcs11.NewAttribute(pkcs11.CKA_UNWRAP, true),
|
||||
pkcs11.NewAttribute(pkcs11.CKA_VALUE_LEN, NoiseKeySize),
|
||||
}
|
||||
|
||||
// Set up the parameters which include the peer's public key
|
||||
|
||||
52
pki.go
52
pki.go
@@ -100,36 +100,41 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
||||
currentState := p.cs.Load()
|
||||
if newState.v1Cert != nil {
|
||||
if currentState.v1Cert == nil {
|
||||
//adding certs is fine, actually. Networks-in-common confirmed in newCertState().
|
||||
} else {
|
||||
return util.NewContextualError("v1 certificate was added, restart required", nil, err)
|
||||
}
|
||||
|
||||
// did IP in cert change? if so, don't set
|
||||
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
|
||||
return util.NewContextualError(
|
||||
"Networks in new cert was different from old",
|
||||
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks(), "cert_version": cert.Version1},
|
||||
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
|
||||
return util.NewContextualError(
|
||||
"Curve in new v1 cert was different from old",
|
||||
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve(), "cert_version": cert.Version1},
|
||||
"Curve in new cert was different from old",
|
||||
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
} else if currentState.v1Cert != nil {
|
||||
//TODO: CERT-V2 we should be able to tear this down
|
||||
return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
|
||||
}
|
||||
|
||||
if newState.v2Cert != nil {
|
||||
if currentState.v2Cert == nil {
|
||||
//adding certs is fine, actually
|
||||
} else {
|
||||
return util.NewContextualError("v2 certificate was added, restart required", nil, err)
|
||||
}
|
||||
|
||||
// did IP in cert change? if so, don't set
|
||||
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
|
||||
return util.NewContextualError(
|
||||
"Networks in new cert was different from old",
|
||||
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks(), "cert_version": cert.Version2},
|
||||
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
@@ -137,25 +142,13 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
||||
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
|
||||
return util.NewContextualError(
|
||||
"Curve in new cert was different from old",
|
||||
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve(), "cert_version": cert.Version2},
|
||||
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
} else if currentState.v2Cert != nil {
|
||||
//newState.v1Cert is non-nil bc empty certstates aren't permitted
|
||||
if newState.v1Cert == nil {
|
||||
return util.NewContextualError("v1 and v2 certs are nil, this should be impossible", nil, err)
|
||||
}
|
||||
//if we're going to v1-only, we need to make sure we didn't orphan any v2-cert vpnaddrs
|
||||
if !slices.Equal(currentState.v2Cert.Networks(), newState.v1Cert.Networks()) {
|
||||
return util.NewContextualError(
|
||||
"Removing a V2 cert is not permitted unless it has identical networks to the new V1 cert",
|
||||
m{"new_v1_networks": newState.v1Cert.Networks(), "old_v2_networks": currentState.v2Cert.Networks()},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
|
||||
}
|
||||
|
||||
// Cipher cant be hot swapped so just leave it at what it was before
|
||||
@@ -180,6 +173,7 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
||||
|
||||
p.cs.Store(newState)
|
||||
|
||||
//TODO: CERT-V2 newState needs a stringer that does json
|
||||
if initial {
|
||||
p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
|
||||
} else {
|
||||
@@ -365,9 +359,7 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
|
||||
return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
|
||||
}
|
||||
|
||||
if v1.Networks()[0] != v2.Networks()[0] {
|
||||
return nil, util.NewContextualError("v1 and v2 networks are not the same", nil, nil)
|
||||
}
|
||||
//TODO: CERT-V2 make sure v2 has v1s address
|
||||
|
||||
cs.initiatingVersion = dv
|
||||
}
|
||||
@@ -523,14 +515,10 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
|
||||
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
|
||||
}
|
||||
|
||||
bl := c.GetStringSlice("pki.blocklist", []string{})
|
||||
if len(bl) > 0 {
|
||||
for _, fp := range bl {
|
||||
for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
|
||||
l.WithField("fingerprint", fp).Info("Blocklisting cert")
|
||||
caPool.BlocklistFingerprint(fp)
|
||||
}
|
||||
|
||||
l.WithField("fingerprintCount", len(bl)).Info("Blocklisted certificates")
|
||||
}
|
||||
|
||||
return caPool, nil
|
||||
}
|
||||
|
||||
@@ -190,7 +190,7 @@ type RemoteList struct {
|
||||
// The full list of vpn addresses assigned to this host
|
||||
vpnAddrs []netip.Addr
|
||||
|
||||
// A deduplicated set of underlay addresses. Any accessor should lock beforehand.
|
||||
// A deduplicated set of addresses. Any accessor should lock beforehand.
|
||||
addrs []netip.AddrPort
|
||||
|
||||
// A set of relay addresses. VpnIp addresses that the remote identified as relays.
|
||||
@@ -202,9 +202,7 @@ type RemoteList struct {
|
||||
cache map[netip.Addr]*cache
|
||||
|
||||
hr *hostnamesResults
|
||||
|
||||
// shouldAdd is a nillable function that decides if x should be added to addrs.
|
||||
shouldAdd func(vpnAddrs []netip.Addr, x netip.Addr) bool
|
||||
shouldAdd func(netip.Addr) bool
|
||||
|
||||
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
|
||||
// They should not be tried again during a handshake
|
||||
@@ -215,7 +213,7 @@ type RemoteList struct {
|
||||
}
|
||||
|
||||
// NewRemoteList creates a new empty RemoteList
|
||||
func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func([]netip.Addr, netip.Addr) bool) *RemoteList {
|
||||
func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList {
|
||||
r := &RemoteList{
|
||||
vpnAddrs: make([]netip.Addr, len(vpnAddrs)),
|
||||
addrs: make([]netip.AddrPort, 0),
|
||||
@@ -370,15 +368,6 @@ func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort {
|
||||
return c
|
||||
}
|
||||
|
||||
// RefreshFromHandshake locks and updates the RemoteList to account for data learned upon a completed handshake
|
||||
func (r *RemoteList) RefreshFromHandshake(vpnAddrs []netip.Addr) {
|
||||
r.Lock()
|
||||
r.badRemotes = nil
|
||||
r.vpnAddrs = make([]netip.Addr, len(vpnAddrs))
|
||||
copy(r.vpnAddrs, vpnAddrs)
|
||||
r.Unlock()
|
||||
}
|
||||
|
||||
// ResetBlockedRemotes locks and clears the blocked remotes list
|
||||
func (r *RemoteList) ResetBlockedRemotes() {
|
||||
r.Lock()
|
||||
@@ -588,7 +577,7 @@ func (r *RemoteList) unlockedCollect() {
|
||||
|
||||
dnsAddrs := r.hr.GetAddrs()
|
||||
for _, addr := range dnsAddrs {
|
||||
if r.shouldAdd == nil || r.shouldAdd(r.vpnAddrs, addr.Addr()) {
|
||||
if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
|
||||
if !r.unlockedIsBad(addr) {
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
|
||||
@@ -9,10 +9,13 @@ import (
|
||||
"math"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/overlay"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
@@ -43,7 +46,14 @@ type Service struct {
|
||||
}
|
||||
}
|
||||
|
||||
func New(control *nebula.Control) (*Service, error) {
|
||||
func New(config *config.C) (*Service, error) {
|
||||
logger := logrus.New()
|
||||
logger.Out = os.Stdout
|
||||
|
||||
control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
control.Start()
|
||||
|
||||
ctx := control.Context()
|
||||
|
||||
@@ -5,19 +5,15 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"dario.cat/mergo"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/cert_test"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/overlay"
|
||||
"go.yaml.in/yaml/v3"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type m = map[string]any
|
||||
@@ -75,15 +71,7 @@ func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp n
|
||||
panic(err)
|
||||
}
|
||||
|
||||
logger := logrus.New()
|
||||
logger.Out = os.Stdout
|
||||
|
||||
control, err := nebula.Main(&c, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
s, err := New(control)
|
||||
s, err := New(&c)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
package udp
|
||||
|
||||
import "errors"
|
||||
|
||||
var ErrInvalidIPv6RemoteForSocket = errors.New("listener is IPv4, but writing to IPv6 remote")
|
||||
@@ -3,62 +3,20 @@
|
||||
|
||||
package udp
|
||||
|
||||
// Darwin support is primarily implemented in udp_generic, besides NewListenConfig
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type StdConn struct {
|
||||
*net.UDPConn
|
||||
isV4 bool
|
||||
sysFd uintptr
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
var _ Conn = &StdConn{}
|
||||
|
||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
lc := NewListenConfig(multi)
|
||||
pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if uc, ok := pc.(*net.UDPConn); ok {
|
||||
c := &StdConn{UDPConn: uc, l: l}
|
||||
|
||||
rc, err := uc.SyscallConn()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open udp socket: %w", err)
|
||||
}
|
||||
|
||||
err = rc.Control(func(fd uintptr) {
|
||||
c.sysFd = fd
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get udp fd: %w", err)
|
||||
}
|
||||
|
||||
la, err := c.LocalAddr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.isV4 = la.Addr().Is4()
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unexpected PacketConn: %T %#v", pc, pc)
|
||||
return NewGenericListener(l, ip, port, multi, batch)
|
||||
}
|
||||
|
||||
func NewListenConfig(multi bool) net.ListenConfig {
|
||||
@@ -85,116 +43,16 @@ func NewListenConfig(multi bool) net.ListenConfig {
|
||||
}
|
||||
}
|
||||
|
||||
//go:linkname sendto golang.org/x/sys/unix.sendto
|
||||
//go:noescape
|
||||
func sendto(s int, buf []byte, flags int, to unsafe.Pointer, addrlen int32) (err error)
|
||||
|
||||
func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
|
||||
var sa unsafe.Pointer
|
||||
var addrLen int32
|
||||
|
||||
if u.isV4 {
|
||||
if ap.Addr().Is6() {
|
||||
return ErrInvalidIPv6RemoteForSocket
|
||||
}
|
||||
|
||||
var rsa unix.RawSockaddrInet6
|
||||
rsa.Family = unix.AF_INET6
|
||||
rsa.Addr = ap.Addr().As16()
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
|
||||
sa = unsafe.Pointer(&rsa)
|
||||
addrLen = syscall.SizeofSockaddrInet4
|
||||
} else {
|
||||
var rsa unix.RawSockaddrInet6
|
||||
rsa.Family = unix.AF_INET6
|
||||
rsa.Addr = ap.Addr().As16()
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
|
||||
sa = unsafe.Pointer(&rsa)
|
||||
addrLen = syscall.SizeofSockaddrInet6
|
||||
}
|
||||
|
||||
// Golang stdlib doesn't handle EAGAIN correctly in some situations so we do writes ourselves
|
||||
// See https://github.com/golang/go/issues/73919
|
||||
for {
|
||||
//_, _, err := unix.Syscall6(unix.SYS_SENDTO, u.sysFd, uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)), 0, sa, addrLen)
|
||||
err := sendto(int(u.sysFd), b, 0, sa, addrLen)
|
||||
if err == nil {
|
||||
// Written, get out before the error handling
|
||||
return nil
|
||||
}
|
||||
|
||||
if errors.Is(err, syscall.EINTR) {
|
||||
// Write was interrupted, retry
|
||||
continue
|
||||
}
|
||||
|
||||
if errors.Is(err, syscall.EAGAIN) {
|
||||
return &net.OpError{Op: "sendto", Err: unix.EWOULDBLOCK}
|
||||
}
|
||||
|
||||
if errors.Is(err, syscall.EBADF) {
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
return &net.OpError{Op: "sendto", Err: err}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
||||
a := u.UDPConn.LocalAddr()
|
||||
|
||||
switch v := a.(type) {
|
||||
case *net.UDPAddr:
|
||||
addr, ok := netip.AddrFromSlice(v.IP)
|
||||
if !ok {
|
||||
return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP)
|
||||
}
|
||||
return netip.AddrPortFrom(addr, uint16(v.Port)), nil
|
||||
|
||||
default:
|
||||
return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *StdConn) ReloadConfig(c *config.C) {
|
||||
// TODO
|
||||
}
|
||||
|
||||
func NewUDPStatsEmitter(udpConns []Conn) func() {
|
||||
// No UDP stats for non-linux
|
||||
return func() {}
|
||||
}
|
||||
|
||||
func (u *StdConn) ListenOut(r EncReader) {
|
||||
buffer := make([]byte, MTU)
|
||||
|
||||
for {
|
||||
// Just read one packet at a time
|
||||
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
||||
func (u *GenericConn) Rebind() error {
|
||||
rc, err := u.UDPConn.SyscallConn()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||
return
|
||||
}
|
||||
|
||||
u.l.WithError(err).Error("unexpected udp socket receive error")
|
||||
}
|
||||
|
||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
||||
}
|
||||
}
|
||||
|
||||
func (u *StdConn) Rebind() error {
|
||||
var err error
|
||||
if u.isV4 {
|
||||
err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, 0)
|
||||
} else {
|
||||
err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, 0)
|
||||
return err
|
||||
}
|
||||
|
||||
return rc.Control(func(fd uintptr) {
|
||||
err := syscall.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, 0)
|
||||
if err != nil {
|
||||
u.l.WithError(err).Error("Failed to rebind udp socket")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
//go:build (!linux || android) && !e2e_testing && !darwin
|
||||
//go:build (!linux || android) && !e2e_testing
|
||||
// +build !linux android
|
||||
// +build !e2e_testing
|
||||
// +build !darwin
|
||||
|
||||
// udp_generic implements the nebula UDP interface in pure Go stdlib. This
|
||||
// means it can be used on platforms like Darwin and Windows.
|
||||
|
||||
@@ -221,7 +221,7 @@ func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
|
||||
|
||||
func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
|
||||
if !ip.Addr().Is4() {
|
||||
return ErrInvalidIPv6RemoteForSocket
|
||||
return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
|
||||
}
|
||||
|
||||
var rsa unix.RawSockaddrInet4
|
||||
|
||||
@@ -92,25 +92,6 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error {
|
||||
// Enable v4 for this socket
|
||||
syscall.SetsockoptInt(syscall.Handle(u.sock), syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
|
||||
|
||||
// Disable reporting of PORT_UNREACHABLE and NET_UNREACHABLE errors from the UDP socket receive call.
|
||||
// These errors are returned on Windows during UDP receives based on the receipt of ICMP packets. Disable
|
||||
// the UDP receive error returns with these ioctl calls.
|
||||
ret := uint32(0)
|
||||
flag := uint32(0)
|
||||
size := uint32(unsafe.Sizeof(flag))
|
||||
err = syscall.WSAIoctl(syscall.Handle(u.sock), syscall.SIO_UDP_CONNRESET, (*byte)(unsafe.Pointer(&flag)), size, nil, 0, &ret, nil, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ret = 0
|
||||
flag = 0
|
||||
size = uint32(unsafe.Sizeof(flag))
|
||||
SIO_UDP_NETRESET := uint32(syscall.IOC_IN | syscall.IOC_VENDOR | 15)
|
||||
err = syscall.WSAIoctl(syscall.Handle(u.sock), SIO_UDP_NETRESET, (*byte)(unsafe.Pointer(&flag)), size, nil, 0, &ret, nil, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = u.rx.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -141,13 +122,9 @@ func (u *RIOConn) ListenOut(r EncReader) {
|
||||
// Just read one packet at a time
|
||||
n, rua, err := u.receive(buffer)
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||
return
|
||||
}
|
||||
u.l.WithError(err).Error("unexpected udp socket receive error")
|
||||
continue
|
||||
}
|
||||
|
||||
r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n])
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user