mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 16:34:25 +01:00
Compare commits
18 Commits
jay.wren-d
...
cert-v2-mo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f8fe454972 | ||
|
|
6a96df18cc | ||
|
|
fbff6a1487 | ||
|
|
e4daed3563 | ||
|
|
1ad0f57c1e | ||
|
|
84dd988f01 | ||
|
|
8704047395 | ||
|
|
3f31517018 | ||
|
|
21a117a156 | ||
|
|
9d310e72c2 | ||
|
|
5380fef7b0 | ||
|
|
602dca8508 | ||
|
|
2b1a59c779 | ||
|
|
028d31c011 | ||
|
|
8adba3960b | ||
|
|
50850eeaf2 | ||
|
|
f30085eab8 | ||
|
|
f2c32421c4 |
2
.github/workflows/gofmt.yml
vendored
2
.github/workflows/gofmt.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
|||||||
|
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: '1.24'
|
go-version: '1.22'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: Install goimports
|
- name: Install goimports
|
||||||
|
|||||||
8
.github/workflows/release.yml
vendored
8
.github/workflows/release.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
|||||||
|
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: '1.24'
|
go-version: '1.22'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
@@ -37,7 +37,7 @@ jobs:
|
|||||||
|
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: '1.24'
|
go-version: '1.22'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
@@ -70,12 +70,12 @@ jobs:
|
|||||||
|
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: '1.24'
|
go-version: '1.22'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: Import certificates
|
- name: Import certificates
|
||||||
if: env.HAS_SIGNING_CREDS == 'true'
|
if: env.HAS_SIGNING_CREDS == 'true'
|
||||||
uses: Apple-Actions/import-codesign-certs@v5
|
uses: Apple-Actions/import-codesign-certs@v3
|
||||||
with:
|
with:
|
||||||
p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }}
|
p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }}
|
||||||
p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }}
|
p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }}
|
||||||
|
|||||||
3
.github/workflows/smoke-extra.yml
vendored
3
.github/workflows/smoke-extra.yml
vendored
@@ -27,9 +27,6 @@ jobs:
|
|||||||
go-version-file: 'go.mod'
|
go-version-file: 'go.mod'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: add hashicorp source
|
|
||||||
run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list
|
|
||||||
|
|
||||||
- name: install vagrant
|
- name: install vagrant
|
||||||
run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox
|
run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox
|
||||||
|
|
||||||
|
|||||||
2
.github/workflows/smoke.yml
vendored
2
.github/workflows/smoke.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
|||||||
|
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: '1.24'
|
go-version: '1.22'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: build
|
- name: build
|
||||||
|
|||||||
10
.github/workflows/smoke/build.sh
vendored
10
.github/workflows/smoke/build.sh
vendored
@@ -5,10 +5,6 @@ set -e -x
|
|||||||
rm -rf ./build
|
rm -rf ./build
|
||||||
mkdir ./build
|
mkdir ./build
|
||||||
|
|
||||||
# TODO: Assumes your docker bridge network is a /24, and the first container that launches will be .1
|
|
||||||
# - We could make this better by launching the lighthouse first and then fetching what IP it is.
|
|
||||||
NET="$(docker network inspect bridge -f '{{ range .IPAM.Config }}{{ .Subnet }}{{ end }}' | cut -d. -f1-3)"
|
|
||||||
|
|
||||||
(
|
(
|
||||||
cd build
|
cd build
|
||||||
|
|
||||||
@@ -25,16 +21,16 @@ NET="$(docker network inspect bridge -f '{{ range .IPAM.Config }}{{ .Subnet }}{{
|
|||||||
../genconfig.sh >lighthouse1.yml
|
../genconfig.sh >lighthouse1.yml
|
||||||
|
|
||||||
HOST="host2" \
|
HOST="host2" \
|
||||||
LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
|
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
|
||||||
../genconfig.sh >host2.yml
|
../genconfig.sh >host2.yml
|
||||||
|
|
||||||
HOST="host3" \
|
HOST="host3" \
|
||||||
LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
|
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
|
||||||
INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
|
INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
|
||||||
../genconfig.sh >host3.yml
|
../genconfig.sh >host3.yml
|
||||||
|
|
||||||
HOST="host4" \
|
HOST="host4" \
|
||||||
LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
|
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
|
||||||
OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
|
OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
|
||||||
../genconfig.sh >host4.yml
|
../genconfig.sh >host4.yml
|
||||||
|
|
||||||
|
|||||||
34
.github/workflows/smoke/smoke-vagrant.sh
vendored
34
.github/workflows/smoke/smoke-vagrant.sh
vendored
@@ -29,13 +29,13 @@ docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
|
|||||||
docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
|
docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
|
||||||
|
|
||||||
vagrant up
|
vagrant up
|
||||||
vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" -- -T
|
vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test"
|
||||||
|
|
||||||
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
|
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
|
||||||
sleep 1
|
sleep 1
|
||||||
docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
|
docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
|
||||||
sleep 1
|
sleep 1
|
||||||
vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" 2>&1 -- -T | tee logs/host3 | sed -u 's/^/ [host3] /' &
|
vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" &
|
||||||
sleep 15
|
sleep 15
|
||||||
|
|
||||||
# grab tcpdump pcaps for debugging
|
# grab tcpdump pcaps for debugging
|
||||||
@@ -46,8 +46,8 @@ docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host
|
|||||||
# vagrant ssh -c "tcpdump -i nebula1 -q -w - -U" 2>logs/host3.inside.log >logs/host3.inside.pcap &
|
# vagrant ssh -c "tcpdump -i nebula1 -q -w - -U" 2>logs/host3.inside.log >logs/host3.inside.pcap &
|
||||||
# vagrant ssh -c "tcpdump -i eth0 -q -w - -U" 2>logs/host3.outside.log >logs/host3.outside.pcap &
|
# vagrant ssh -c "tcpdump -i eth0 -q -w - -U" 2>logs/host3.outside.log >logs/host3.outside.pcap &
|
||||||
|
|
||||||
#docker exec host2 ncat -nklv 0.0.0.0 2000 &
|
docker exec host2 ncat -nklv 0.0.0.0 2000 &
|
||||||
#vagrant ssh -c "ncat -nklv 0.0.0.0 2000" &
|
vagrant ssh -c "ncat -nklv 0.0.0.0 2000" &
|
||||||
#docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 &
|
#docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 &
|
||||||
#vagrant ssh -c "ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000" &
|
#vagrant ssh -c "ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000" &
|
||||||
|
|
||||||
@@ -68,11 +68,11 @@ docker exec host2 ping -c1 192.168.100.1
|
|||||||
# Should fail because not allowed by host3 inbound firewall
|
# Should fail because not allowed by host3 inbound firewall
|
||||||
! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
|
! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
|
||||||
|
|
||||||
#set +x
|
set +x
|
||||||
#echo
|
echo
|
||||||
#echo " *** Testing ncat from host2"
|
echo " *** Testing ncat from host2"
|
||||||
#echo
|
echo
|
||||||
#set -x
|
set -x
|
||||||
# Should fail because not allowed by host3 inbound firewall
|
# Should fail because not allowed by host3 inbound firewall
|
||||||
#! docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1
|
#! docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1
|
||||||
#! docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1
|
#! docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1
|
||||||
@@ -82,18 +82,18 @@ echo
|
|||||||
echo " *** Testing ping from host3"
|
echo " *** Testing ping from host3"
|
||||||
echo
|
echo
|
||||||
set -x
|
set -x
|
||||||
vagrant ssh -c "ping -c1 192.168.100.1" -- -T
|
vagrant ssh -c "ping -c1 192.168.100.1"
|
||||||
vagrant ssh -c "ping -c1 192.168.100.2" -- -T
|
vagrant ssh -c "ping -c1 192.168.100.2"
|
||||||
|
|
||||||
#set +x
|
set +x
|
||||||
#echo
|
echo
|
||||||
#echo " *** Testing ncat from host3"
|
echo " *** Testing ncat from host3"
|
||||||
#echo
|
echo
|
||||||
#set -x
|
set -x
|
||||||
#vagrant ssh -c "ncat -nzv -w5 192.168.100.2 2000"
|
#vagrant ssh -c "ncat -nzv -w5 192.168.100.2 2000"
|
||||||
#vagrant ssh -c "ncat -nzuv -w5 192.168.100.2 3000" | grep -q host2
|
#vagrant ssh -c "ncat -nzuv -w5 192.168.100.2 3000" | grep -q host2
|
||||||
|
|
||||||
vagrant ssh -c "sudo xargs kill </nebula/pid" -- -T
|
vagrant ssh -c "sudo xargs kill </nebula/pid"
|
||||||
docker exec host2 sh -c 'kill 1'
|
docker exec host2 sh -c 'kill 1'
|
||||||
docker exec lighthouse1 sh -c 'kill 1'
|
docker exec lighthouse1 sh -c 'kill 1'
|
||||||
sleep 1
|
sleep 1
|
||||||
|
|||||||
18
.github/workflows/test.yml
vendored
18
.github/workflows/test.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
|||||||
|
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: '1.24'
|
go-version: '1.22'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
@@ -31,11 +31,6 @@ jobs:
|
|||||||
- name: Vet
|
- name: Vet
|
||||||
run: make vet
|
run: make vet
|
||||||
|
|
||||||
- name: golangci-lint
|
|
||||||
uses: golangci/golangci-lint-action@v7
|
|
||||||
with:
|
|
||||||
version: v2.0
|
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: make test
|
run: make test
|
||||||
|
|
||||||
@@ -60,7 +55,7 @@ jobs:
|
|||||||
|
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: '1.24'
|
go-version: '1.22'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
@@ -70,7 +65,7 @@ jobs:
|
|||||||
run: make test-boringcrypto
|
run: make test-boringcrypto
|
||||||
|
|
||||||
- name: End 2 end
|
- name: End 2 end
|
||||||
run: make e2e GOEXPERIMENT=boringcrypto CGO_ENABLED=1 TEST_ENV="TEST_LOGS=1" TEST_FLAGS="-v -ldflags -checklinkname=0"
|
run: make e2evv GOEXPERIMENT=boringcrypto CGO_ENABLED=1
|
||||||
|
|
||||||
test-linux-pkcs11:
|
test-linux-pkcs11:
|
||||||
name: Build and test on linux with pkcs11
|
name: Build and test on linux with pkcs11
|
||||||
@@ -102,7 +97,7 @@ jobs:
|
|||||||
|
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: '1.24'
|
go-version: '1.22'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: Build nebula
|
- name: Build nebula
|
||||||
@@ -114,11 +109,6 @@ jobs:
|
|||||||
- name: Vet
|
- name: Vet
|
||||||
run: make vet
|
run: make vet
|
||||||
|
|
||||||
- name: golangci-lint
|
|
||||||
uses: golangci/golangci-lint-action@v7
|
|
||||||
with:
|
|
||||||
version: v2.0
|
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: make test
|
run: make test
|
||||||
|
|
||||||
|
|||||||
@@ -1,23 +0,0 @@
|
|||||||
version: "2"
|
|
||||||
linters:
|
|
||||||
default: none
|
|
||||||
enable:
|
|
||||||
- testifylint
|
|
||||||
exclusions:
|
|
||||||
generated: lax
|
|
||||||
presets:
|
|
||||||
- comments
|
|
||||||
- common-false-positives
|
|
||||||
- legacy
|
|
||||||
- std-error-handling
|
|
||||||
paths:
|
|
||||||
- third_party$
|
|
||||||
- builtin$
|
|
||||||
- examples$
|
|
||||||
formatters:
|
|
||||||
exclusions:
|
|
||||||
generated: lax
|
|
||||||
paths:
|
|
||||||
- third_party$
|
|
||||||
- builtin$
|
|
||||||
- examples$
|
|
||||||
@@ -7,13 +7,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
|
|
||||||
### Changed
|
|
||||||
|
|
||||||
- `default_local_cidr_any` now defaults to false, meaning that any firewall rule
|
|
||||||
intended to target an `unsafe_routes` entry must explicitly declare it via the
|
|
||||||
`local_cidr` field. This is almost always the intended behavior. This flag is
|
|
||||||
deprecated and will be removed in a future release.
|
|
||||||
|
|
||||||
## [1.9.4] - 2024-09-09
|
## [1.9.4] - 2024-09-09
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|||||||
4
Makefile
4
Makefile
@@ -137,8 +137,6 @@ build/linux-mips-softfloat/%: LDFLAGS += -s -w
|
|||||||
# boringcrypto
|
# boringcrypto
|
||||||
build/linux-amd64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
|
build/linux-amd64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
|
||||||
build/linux-arm64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
|
build/linux-arm64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
|
||||||
build/linux-amd64-boringcrypto/%: LDFLAGS += -checklinkname=0
|
|
||||||
build/linux-arm64-boringcrypto/%: LDFLAGS += -checklinkname=0
|
|
||||||
|
|
||||||
build/%/nebula: .FORCE
|
build/%/nebula: .FORCE
|
||||||
GOOS=$(firstword $(subst -, , $*)) \
|
GOOS=$(firstword $(subst -, , $*)) \
|
||||||
@@ -172,7 +170,7 @@ test:
|
|||||||
go test -v ./...
|
go test -v ./...
|
||||||
|
|
||||||
test-boringcrypto:
|
test-boringcrypto:
|
||||||
GOEXPERIMENT=boringcrypto CGO_ENABLED=1 go test -ldflags "-checklinkname=0" -v ./...
|
GOEXPERIMENT=boringcrypto CGO_ENABLED=1 go test -v ./...
|
||||||
|
|
||||||
test-pkcs11:
|
test-pkcs11:
|
||||||
CGO_ENABLED=1 go test -v -tags pkcs11 ./...
|
CGO_ENABLED=1 go test -v -tags pkcs11 ./...
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ Further documentation can be found [here](https://nebula.defined.net/docs/).
|
|||||||
|
|
||||||
You can read more about Nebula [here](https://medium.com/p/884110a5579).
|
You can read more about Nebula [here](https://medium.com/p/884110a5579).
|
||||||
|
|
||||||
You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-2xqe6e7vn-k_KGi8s13nsr7cvHVvHvuQ).
|
You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/enQtOTA5MDI4NDg3MTg4LTkwY2EwNTI4NzQyMzc0M2ZlODBjNWI3NTY1MzhiOThiMmZlZjVkMTI0NGY4YTMyNjUwMWEyNzNkZTJmYzQxOGU).
|
||||||
|
|
||||||
## Supported Platforms
|
## Supported Platforms
|
||||||
|
|
||||||
@@ -47,7 +47,7 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for
|
|||||||
$ sudo apk add nebula
|
$ sudo apk add nebula
|
||||||
```
|
```
|
||||||
|
|
||||||
- [macOS Homebrew](https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/n/nebula.rb)
|
- [macOS Homebrew](https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/nebula.rb)
|
||||||
```
|
```
|
||||||
$ brew install nebula
|
$ brew install nebula
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ type AllowListNameRule struct {
|
|||||||
|
|
||||||
func NewLocalAllowListFromConfig(c *config.C, k string) (*LocalAllowList, error) {
|
func NewLocalAllowListFromConfig(c *config.C, k string) (*LocalAllowList, error) {
|
||||||
var nameRules []AllowListNameRule
|
var nameRules []AllowListNameRule
|
||||||
handleKey := func(key string, value any) (bool, error) {
|
handleKey := func(key string, value interface{}) (bool, error) {
|
||||||
if key == "interfaces" {
|
if key == "interfaces" {
|
||||||
var err error
|
var err error
|
||||||
nameRules, err = getAllowListInterfaces(k, value)
|
nameRules, err = getAllowListInterfaces(k, value)
|
||||||
@@ -70,7 +70,7 @@ func NewRemoteAllowListFromConfig(c *config.C, k, rangesKey string) (*RemoteAllo
|
|||||||
|
|
||||||
// If the handleKey func returns true, the rest of the parsing is skipped
|
// If the handleKey func returns true, the rest of the parsing is skipped
|
||||||
// for this key. This allows parsing of special values like `interfaces`.
|
// for this key. This allows parsing of special values like `interfaces`.
|
||||||
func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, value any) (bool, error)) (*AllowList, error) {
|
func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
|
||||||
r := c.Get(k)
|
r := c.Get(k)
|
||||||
if r == nil {
|
if r == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@@ -81,8 +81,8 @@ func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, va
|
|||||||
|
|
||||||
// If the handleKey func returns true, the rest of the parsing is skipped
|
// If the handleKey func returns true, the rest of the parsing is skipped
|
||||||
// for this key. This allows parsing of special values like `interfaces`.
|
// for this key. This allows parsing of special values like `interfaces`.
|
||||||
func newAllowList(k string, raw any, handleKey func(key string, value any) (bool, error)) (*AllowList, error) {
|
func newAllowList(k string, raw interface{}, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
|
||||||
rawMap, ok := raw.(map[string]any)
|
rawMap, ok := raw.(map[interface{}]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
|
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
|
||||||
}
|
}
|
||||||
@@ -100,7 +100,12 @@ func newAllowList(k string, raw any, handleKey func(key string, value any) (bool
|
|||||||
rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
|
rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
|
||||||
rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
|
rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
|
||||||
|
|
||||||
for rawCIDR, rawValue := range rawMap {
|
for rawKey, rawValue := range rawMap {
|
||||||
|
rawCIDR, ok := rawKey.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
|
||||||
|
}
|
||||||
|
|
||||||
if handleKey != nil {
|
if handleKey != nil {
|
||||||
handled, err := handleKey(rawCIDR, rawValue)
|
handled, err := handleKey(rawCIDR, rawValue)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -111,7 +116,7 @@ func newAllowList(k string, raw any, handleKey func(key string, value any) (bool
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
value, ok := config.AsBool(rawValue)
|
value, ok := rawValue.(bool)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
|
return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
|
||||||
}
|
}
|
||||||
@@ -168,18 +173,22 @@ func newAllowList(k string, raw any, handleKey func(key string, value any) (bool
|
|||||||
return &AllowList{cidrTree: tree}, nil
|
return &AllowList{cidrTree: tree}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAllowListInterfaces(k string, v any) ([]AllowListNameRule, error) {
|
func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) {
|
||||||
var nameRules []AllowListNameRule
|
var nameRules []AllowListNameRule
|
||||||
|
|
||||||
rawRules, ok := v.(map[string]any)
|
rawRules, ok := v.(map[interface{}]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
|
return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
firstEntry := true
|
firstEntry := true
|
||||||
var allValues bool
|
var allValues bool
|
||||||
for name, rawAllow := range rawRules {
|
for rawName, rawAllow := range rawRules {
|
||||||
allow, ok := config.AsBool(rawAllow)
|
name, ok := rawName.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName)
|
||||||
|
}
|
||||||
|
allow, ok := rawAllow.(bool)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
|
return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
|
||||||
}
|
}
|
||||||
@@ -215,11 +224,16 @@ func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error
|
|||||||
|
|
||||||
remoteAllowRanges := new(bart.Table[*AllowList])
|
remoteAllowRanges := new(bart.Table[*AllowList])
|
||||||
|
|
||||||
rawMap, ok := value.(map[string]any)
|
rawMap, ok := value.(map[interface{}]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
|
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
|
||||||
}
|
}
|
||||||
for rawCIDR, rawValue := range rawMap {
|
for rawKey, rawValue := range rawMap {
|
||||||
|
rawCIDR, ok := rawKey.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
|
||||||
|
}
|
||||||
|
|
||||||
allowList, err := newAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil)
|
allowList, err := newAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -9,33 +9,32 @@ import (
|
|||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewAllowListFromConfig(t *testing.T) {
|
func TestNewAllowListFromConfig(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
c.Settings["allowlist"] = map[string]any{
|
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||||
"192.168.0.0": true,
|
"192.168.0.0": true,
|
||||||
}
|
}
|
||||||
r, err := newAllowListFromConfig(c, "allowlist", nil)
|
r, err := newAllowListFromConfig(c, "allowlist", nil)
|
||||||
require.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'")
|
assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'")
|
||||||
assert.Nil(t, r)
|
assert.Nil(t, r)
|
||||||
|
|
||||||
c.Settings["allowlist"] = map[string]any{
|
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||||
"192.168.0.0/16": "abc",
|
"192.168.0.0/16": "abc",
|
||||||
}
|
}
|
||||||
r, err = newAllowListFromConfig(c, "allowlist", nil)
|
r, err = newAllowListFromConfig(c, "allowlist", nil)
|
||||||
require.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
|
assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
|
||||||
|
|
||||||
c.Settings["allowlist"] = map[string]any{
|
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||||
"192.168.0.0/16": true,
|
"192.168.0.0/16": true,
|
||||||
"10.0.0.0/8": false,
|
"10.0.0.0/8": false,
|
||||||
}
|
}
|
||||||
r, err = newAllowListFromConfig(c, "allowlist", nil)
|
r, err = newAllowListFromConfig(c, "allowlist", nil)
|
||||||
require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
|
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
|
||||||
|
|
||||||
c.Settings["allowlist"] = map[string]any{
|
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||||
"0.0.0.0/0": true,
|
"0.0.0.0/0": true,
|
||||||
"10.0.0.0/8": false,
|
"10.0.0.0/8": false,
|
||||||
"10.42.42.0/24": true,
|
"10.42.42.0/24": true,
|
||||||
@@ -43,9 +42,9 @@ func TestNewAllowListFromConfig(t *testing.T) {
|
|||||||
"fd00:fd00::/16": false,
|
"fd00:fd00::/16": false,
|
||||||
}
|
}
|
||||||
r, err = newAllowListFromConfig(c, "allowlist", nil)
|
r, err = newAllowListFromConfig(c, "allowlist", nil)
|
||||||
require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
|
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
|
||||||
|
|
||||||
c.Settings["allowlist"] = map[string]any{
|
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||||
"0.0.0.0/0": true,
|
"0.0.0.0/0": true,
|
||||||
"10.0.0.0/8": false,
|
"10.0.0.0/8": false,
|
||||||
"10.42.42.0/24": true,
|
"10.42.42.0/24": true,
|
||||||
@@ -55,7 +54,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
|
|||||||
assert.NotNil(t, r)
|
assert.NotNil(t, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Settings["allowlist"] = map[string]any{
|
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||||
"0.0.0.0/0": true,
|
"0.0.0.0/0": true,
|
||||||
"10.0.0.0/8": false,
|
"10.0.0.0/8": false,
|
||||||
"10.42.42.0/24": true,
|
"10.42.42.0/24": true,
|
||||||
@@ -70,25 +69,25 @@ func TestNewAllowListFromConfig(t *testing.T) {
|
|||||||
|
|
||||||
// Test interface names
|
// Test interface names
|
||||||
|
|
||||||
c.Settings["allowlist"] = map[string]any{
|
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||||
"interfaces": map[string]any{
|
"interfaces": map[interface{}]interface{}{
|
||||||
`docker.*`: "foo",
|
`docker.*`: "foo",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
lr, err := NewLocalAllowListFromConfig(c, "allowlist")
|
lr, err := NewLocalAllowListFromConfig(c, "allowlist")
|
||||||
require.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
|
assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
|
||||||
|
|
||||||
c.Settings["allowlist"] = map[string]any{
|
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||||
"interfaces": map[string]any{
|
"interfaces": map[interface{}]interface{}{
|
||||||
`docker.*`: false,
|
`docker.*`: false,
|
||||||
`eth.*`: true,
|
`eth.*`: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
lr, err = NewLocalAllowListFromConfig(c, "allowlist")
|
lr, err = NewLocalAllowListFromConfig(c, "allowlist")
|
||||||
require.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
|
assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
|
||||||
|
|
||||||
c.Settings["allowlist"] = map[string]any{
|
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||||
"interfaces": map[string]any{
|
"interfaces": map[interface{}]interface{}{
|
||||||
`docker.*`: false,
|
`docker.*`: false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -99,7 +98,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAllowList_Allow(t *testing.T) {
|
func TestAllowList_Allow(t *testing.T) {
|
||||||
assert.True(t, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1")))
|
assert.Equal(t, true, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1")))
|
||||||
|
|
||||||
tree := new(bart.Table[bool])
|
tree := new(bart.Table[bool])
|
||||||
tree.Insert(netip.MustParsePrefix("0.0.0.0/0"), true)
|
tree.Insert(netip.MustParsePrefix("0.0.0.0/0"), true)
|
||||||
@@ -112,17 +111,17 @@ func TestAllowList_Allow(t *testing.T) {
|
|||||||
tree.Insert(netip.MustParsePrefix("::2/128"), false)
|
tree.Insert(netip.MustParsePrefix("::2/128"), false)
|
||||||
al := &AllowList{cidrTree: tree}
|
al := &AllowList{cidrTree: tree}
|
||||||
|
|
||||||
assert.True(t, al.Allow(netip.MustParseAddr("1.1.1.1")))
|
assert.Equal(t, true, al.Allow(netip.MustParseAddr("1.1.1.1")))
|
||||||
assert.False(t, al.Allow(netip.MustParseAddr("10.0.0.4")))
|
assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.0.0.4")))
|
||||||
assert.True(t, al.Allow(netip.MustParseAddr("10.42.42.42")))
|
assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.42.42")))
|
||||||
assert.False(t, al.Allow(netip.MustParseAddr("10.42.42.41")))
|
assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.42.42.41")))
|
||||||
assert.True(t, al.Allow(netip.MustParseAddr("10.42.0.1")))
|
assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.0.1")))
|
||||||
assert.True(t, al.Allow(netip.MustParseAddr("::1")))
|
assert.Equal(t, true, al.Allow(netip.MustParseAddr("::1")))
|
||||||
assert.False(t, al.Allow(netip.MustParseAddr("::2")))
|
assert.Equal(t, false, al.Allow(netip.MustParseAddr("::2")))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLocalAllowList_AllowName(t *testing.T) {
|
func TestLocalAllowList_AllowName(t *testing.T) {
|
||||||
assert.True(t, ((*LocalAllowList)(nil)).AllowName("docker0"))
|
assert.Equal(t, true, ((*LocalAllowList)(nil)).AllowName("docker0"))
|
||||||
|
|
||||||
rules := []AllowListNameRule{
|
rules := []AllowListNameRule{
|
||||||
{Name: regexp.MustCompile("^docker.*$"), Allow: false},
|
{Name: regexp.MustCompile("^docker.*$"), Allow: false},
|
||||||
@@ -130,9 +129,9 @@ func TestLocalAllowList_AllowName(t *testing.T) {
|
|||||||
}
|
}
|
||||||
al := &LocalAllowList{nameRules: rules}
|
al := &LocalAllowList{nameRules: rules}
|
||||||
|
|
||||||
assert.False(t, al.AllowName("docker0"))
|
assert.Equal(t, false, al.AllowName("docker0"))
|
||||||
assert.False(t, al.AllowName("tun0"))
|
assert.Equal(t, false, al.AllowName("tun0"))
|
||||||
assert.True(t, al.AllowName("eth0"))
|
assert.Equal(t, true, al.AllowName("eth0"))
|
||||||
|
|
||||||
rules = []AllowListNameRule{
|
rules = []AllowListNameRule{
|
||||||
{Name: regexp.MustCompile("^eth.*$"), Allow: true},
|
{Name: regexp.MustCompile("^eth.*$"), Allow: true},
|
||||||
@@ -140,7 +139,7 @@ func TestLocalAllowList_AllowName(t *testing.T) {
|
|||||||
}
|
}
|
||||||
al = &LocalAllowList{nameRules: rules}
|
al = &LocalAllowList{nameRules: rules}
|
||||||
|
|
||||||
assert.False(t, al.AllowName("docker0"))
|
assert.Equal(t, false, al.AllowName("docker0"))
|
||||||
assert.True(t, al.AllowName("eth0"))
|
assert.Equal(t, true, al.AllowName("eth0"))
|
||||||
assert.True(t, al.AllowName("ens5"))
|
assert.Equal(t, true, al.AllowName("ens5"))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,10 +15,10 @@ func TestCalculatedRemoteApply(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
input, err := netip.ParseAddr("10.0.10.182")
|
input, err := netip.ParseAddr("10.0.10.182")
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
expected, err := netip.ParseAddr("192.168.1.182")
|
expected, err := netip.ParseAddr("192.168.1.182")
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input))
|
assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input))
|
||||||
|
|
||||||
@@ -28,10 +28,10 @@ func TestCalculatedRemoteApply(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
|
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef")
|
expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef")
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
|
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
|
||||||
|
|
||||||
@@ -41,10 +41,10 @@ func TestCalculatedRemoteApply(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
|
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef")
|
expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef")
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
|
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
|
||||||
|
|
||||||
@@ -54,10 +54,10 @@ func TestCalculatedRemoteApply(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
|
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef")
|
expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef")
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
|
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewCAPoolFromBytes(t *testing.T) {
|
func TestNewCAPoolFromBytes(t *testing.T) {
|
||||||
@@ -83,32 +82,32 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe
|
|||||||
}
|
}
|
||||||
|
|
||||||
p, err := NewCAPoolFromPEM([]byte(noNewLines))
|
p, err := NewCAPoolFromPEM([]byte(noNewLines))
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
|
assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
|
||||||
assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
|
assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
|
||||||
|
|
||||||
pp, err := NewCAPoolFromPEM([]byte(withNewLines))
|
pp, err := NewCAPoolFromPEM([]byte(withNewLines))
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
|
assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
|
||||||
assert.Equal(t, pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
|
assert.Equal(t, pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
|
||||||
|
|
||||||
// expired cert, no valid certs
|
// expired cert, no valid certs
|
||||||
ppp, err := NewCAPoolFromPEM([]byte(expired))
|
ppp, err := NewCAPoolFromPEM([]byte(expired))
|
||||||
assert.Equal(t, ErrExpired, err)
|
assert.Equal(t, ErrExpired, err)
|
||||||
assert.Equal(t, "expired", ppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name())
|
assert.Equal(t, ppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name(), "expired")
|
||||||
|
|
||||||
// expired cert, with valid certs
|
// expired cert, with valid certs
|
||||||
pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...))
|
pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...))
|
||||||
assert.Equal(t, ErrExpired, err)
|
assert.Equal(t, ErrExpired, err)
|
||||||
assert.Equal(t, pppp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
|
assert.Equal(t, pppp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
|
||||||
assert.Equal(t, pppp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
|
assert.Equal(t, pppp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
|
||||||
assert.Equal(t, "expired", pppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name())
|
assert.Equal(t, pppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name(), "expired")
|
||||||
assert.Len(t, pppp.CAs, 3)
|
assert.Equal(t, len(pppp.CAs), 3)
|
||||||
|
|
||||||
ppppp, err := NewCAPoolFromPEM([]byte(p256))
|
ppppp, err := NewCAPoolFromPEM([]byte(p256))
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name)
|
assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name)
|
||||||
assert.Len(t, ppppp.CAs, 1)
|
assert.Equal(t, len(ppppp.CAs), 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV1_Verify(t *testing.T) {
|
func TestCertificateV1_Verify(t *testing.T) {
|
||||||
@@ -116,21 +115,21 @@ func TestCertificateV1_Verify(t *testing.T) {
|
|||||||
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
||||||
|
|
||||||
caPool := NewCAPool()
|
caPool := NewCAPool()
|
||||||
require.NoError(t, caPool.AddCA(ca))
|
assert.NoError(t, caPool.AddCA(ca))
|
||||||
|
|
||||||
f, err := c.Fingerprint()
|
f, err := c.Fingerprint()
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
caPool.BlocklistFingerprint(f)
|
caPool.BlocklistFingerprint(f)
|
||||||
|
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.EqualError(t, err, "certificate is in the block list")
|
assert.EqualError(t, err, "certificate is in the block list")
|
||||||
|
|
||||||
caPool.ResetCertBlocklist()
|
caPool.ResetCertBlocklist()
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
|
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
|
||||||
require.EqualError(t, err, "root certificate is expired")
|
assert.EqualError(t, err, "root certificate is expired")
|
||||||
|
|
||||||
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
|
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
|
||||||
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
|
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
@@ -139,11 +138,11 @@ func TestCertificateV1_Verify(t *testing.T) {
|
|||||||
// Test group assertion
|
// Test group assertion
|
||||||
ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
|
ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
|
||||||
caPem, err := ca.MarshalPEM()
|
caPem, err := ca.MarshalPEM()
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
caPool = NewCAPool()
|
caPool = NewCAPool()
|
||||||
b, err := caPool.AddCAFromPEM(caPem)
|
b, err := caPool.AddCAFromPEM(caPem)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
|
|
||||||
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
|
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
|
||||||
@@ -151,9 +150,9 @@ func TestCertificateV1_Verify(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV1_VerifyP256(t *testing.T) {
|
func TestCertificateV1_VerifyP256(t *testing.T) {
|
||||||
@@ -161,21 +160,21 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
|
|||||||
c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
||||||
|
|
||||||
caPool := NewCAPool()
|
caPool := NewCAPool()
|
||||||
require.NoError(t, caPool.AddCA(ca))
|
assert.NoError(t, caPool.AddCA(ca))
|
||||||
|
|
||||||
f, err := c.Fingerprint()
|
f, err := c.Fingerprint()
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
caPool.BlocklistFingerprint(f)
|
caPool.BlocklistFingerprint(f)
|
||||||
|
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.EqualError(t, err, "certificate is in the block list")
|
assert.EqualError(t, err, "certificate is in the block list")
|
||||||
|
|
||||||
caPool.ResetCertBlocklist()
|
caPool.ResetCertBlocklist()
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
|
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
|
||||||
require.EqualError(t, err, "root certificate is expired")
|
assert.EqualError(t, err, "root certificate is expired")
|
||||||
|
|
||||||
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
|
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
|
||||||
NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
@@ -184,11 +183,11 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
|
|||||||
// Test group assertion
|
// Test group assertion
|
||||||
ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
|
ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
|
||||||
caPem, err := ca.MarshalPEM()
|
caPem, err := ca.MarshalPEM()
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
caPool = NewCAPool()
|
caPool = NewCAPool()
|
||||||
b, err := caPool.AddCAFromPEM(caPem)
|
b, err := caPool.AddCAFromPEM(caPem)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
|
|
||||||
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
|
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
|
||||||
@@ -197,7 +196,7 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
|
|||||||
|
|
||||||
c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV1_Verify_IPs(t *testing.T) {
|
func TestCertificateV1_Verify_IPs(t *testing.T) {
|
||||||
@@ -206,11 +205,11 @@ func TestCertificateV1_Verify_IPs(t *testing.T) {
|
|||||||
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
|
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
|
||||||
|
|
||||||
caPem, err := ca.MarshalPEM()
|
caPem, err := ca.MarshalPEM()
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
caPool := NewCAPool()
|
caPool := NewCAPool()
|
||||||
b, err := caPool.AddCAFromPEM(caPem)
|
b, err := caPool.AddCAFromPEM(caPem)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
|
|
||||||
// ip is outside the network
|
// ip is outside the network
|
||||||
@@ -246,25 +245,25 @@ func TestCertificateV1_Verify_IPs(t *testing.T) {
|
|||||||
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
|
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
|
||||||
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
|
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
// Exact matches
|
// Exact matches
|
||||||
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
|
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
// Exact matches reversed
|
// Exact matches reversed
|
||||||
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
|
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
// Exact matches reversed with just 1
|
// Exact matches reversed with just 1
|
||||||
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
|
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV1_Verify_Subnets(t *testing.T) {
|
func TestCertificateV1_Verify_Subnets(t *testing.T) {
|
||||||
@@ -273,11 +272,11 @@ func TestCertificateV1_Verify_Subnets(t *testing.T) {
|
|||||||
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
|
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
|
||||||
|
|
||||||
caPem, err := ca.MarshalPEM()
|
caPem, err := ca.MarshalPEM()
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
caPool := NewCAPool()
|
caPool := NewCAPool()
|
||||||
b, err := caPool.AddCAFromPEM(caPem)
|
b, err := caPool.AddCAFromPEM(caPem)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
|
|
||||||
// ip is outside the network
|
// ip is outside the network
|
||||||
@@ -312,27 +311,27 @@ func TestCertificateV1_Verify_Subnets(t *testing.T) {
|
|||||||
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
|
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
|
||||||
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
|
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
|
||||||
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
|
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
// Exact matches
|
// Exact matches
|
||||||
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
|
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
// Exact matches reversed
|
// Exact matches reversed
|
||||||
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
|
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
// Exact matches reversed with just 1
|
// Exact matches reversed with just 1
|
||||||
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
|
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV2_Verify(t *testing.T) {
|
func TestCertificateV2_Verify(t *testing.T) {
|
||||||
@@ -340,21 +339,21 @@ func TestCertificateV2_Verify(t *testing.T) {
|
|||||||
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
||||||
|
|
||||||
caPool := NewCAPool()
|
caPool := NewCAPool()
|
||||||
require.NoError(t, caPool.AddCA(ca))
|
assert.NoError(t, caPool.AddCA(ca))
|
||||||
|
|
||||||
f, err := c.Fingerprint()
|
f, err := c.Fingerprint()
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
caPool.BlocklistFingerprint(f)
|
caPool.BlocklistFingerprint(f)
|
||||||
|
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.EqualError(t, err, "certificate is in the block list")
|
assert.EqualError(t, err, "certificate is in the block list")
|
||||||
|
|
||||||
caPool.ResetCertBlocklist()
|
caPool.ResetCertBlocklist()
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
|
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
|
||||||
require.EqualError(t, err, "root certificate is expired")
|
assert.EqualError(t, err, "root certificate is expired")
|
||||||
|
|
||||||
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
|
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
|
||||||
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
|
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
@@ -363,11 +362,11 @@ func TestCertificateV2_Verify(t *testing.T) {
|
|||||||
// Test group assertion
|
// Test group assertion
|
||||||
ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
|
ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
|
||||||
caPem, err := ca.MarshalPEM()
|
caPem, err := ca.MarshalPEM()
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
caPool = NewCAPool()
|
caPool = NewCAPool()
|
||||||
b, err := caPool.AddCAFromPEM(caPem)
|
b, err := caPool.AddCAFromPEM(caPem)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
|
|
||||||
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
|
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
|
||||||
@@ -375,9 +374,9 @@ func TestCertificateV2_Verify(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV2_VerifyP256(t *testing.T) {
|
func TestCertificateV2_VerifyP256(t *testing.T) {
|
||||||
@@ -385,21 +384,21 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
|
|||||||
c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
||||||
|
|
||||||
caPool := NewCAPool()
|
caPool := NewCAPool()
|
||||||
require.NoError(t, caPool.AddCA(ca))
|
assert.NoError(t, caPool.AddCA(ca))
|
||||||
|
|
||||||
f, err := c.Fingerprint()
|
f, err := c.Fingerprint()
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
caPool.BlocklistFingerprint(f)
|
caPool.BlocklistFingerprint(f)
|
||||||
|
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.EqualError(t, err, "certificate is in the block list")
|
assert.EqualError(t, err, "certificate is in the block list")
|
||||||
|
|
||||||
caPool.ResetCertBlocklist()
|
caPool.ResetCertBlocklist()
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
|
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
|
||||||
require.EqualError(t, err, "root certificate is expired")
|
assert.EqualError(t, err, "root certificate is expired")
|
||||||
|
|
||||||
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
|
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
|
||||||
NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
@@ -408,11 +407,11 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
|
|||||||
// Test group assertion
|
// Test group assertion
|
||||||
ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
|
ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
|
||||||
caPem, err := ca.MarshalPEM()
|
caPem, err := ca.MarshalPEM()
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
caPool = NewCAPool()
|
caPool = NewCAPool()
|
||||||
b, err := caPool.AddCAFromPEM(caPem)
|
b, err := caPool.AddCAFromPEM(caPem)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
|
|
||||||
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
|
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
|
||||||
@@ -421,7 +420,7 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
|
|||||||
|
|
||||||
c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV2_Verify_IPs(t *testing.T) {
|
func TestCertificateV2_Verify_IPs(t *testing.T) {
|
||||||
@@ -430,11 +429,11 @@ func TestCertificateV2_Verify_IPs(t *testing.T) {
|
|||||||
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
|
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
|
||||||
|
|
||||||
caPem, err := ca.MarshalPEM()
|
caPem, err := ca.MarshalPEM()
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
caPool := NewCAPool()
|
caPool := NewCAPool()
|
||||||
b, err := caPool.AddCAFromPEM(caPem)
|
b, err := caPool.AddCAFromPEM(caPem)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
|
|
||||||
// ip is outside the network
|
// ip is outside the network
|
||||||
@@ -470,25 +469,25 @@ func TestCertificateV2_Verify_IPs(t *testing.T) {
|
|||||||
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
|
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
|
||||||
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
|
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
// Exact matches
|
// Exact matches
|
||||||
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
|
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
// Exact matches reversed
|
// Exact matches reversed
|
||||||
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
|
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
// Exact matches reversed with just 1
|
// Exact matches reversed with just 1
|
||||||
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
|
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV2_Verify_Subnets(t *testing.T) {
|
func TestCertificateV2_Verify_Subnets(t *testing.T) {
|
||||||
@@ -497,11 +496,11 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) {
|
|||||||
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
|
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
|
||||||
|
|
||||||
caPem, err := ca.MarshalPEM()
|
caPem, err := ca.MarshalPEM()
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
caPool := NewCAPool()
|
caPool := NewCAPool()
|
||||||
b, err := caPool.AddCAFromPEM(caPem)
|
b, err := caPool.AddCAFromPEM(caPem)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
|
|
||||||
// ip is outside the network
|
// ip is outside the network
|
||||||
@@ -536,25 +535,25 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) {
|
|||||||
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
|
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
|
||||||
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
|
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
|
||||||
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
|
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
// Exact matches
|
// Exact matches
|
||||||
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
|
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
// Exact matches reversed
|
// Exact matches reversed
|
||||||
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
|
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
// Exact matches reversed with just 1
|
// Exact matches reversed with just 1
|
||||||
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
|
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
}
|
}
|
||||||
|
|||||||
22
cert/cert.go
22
cert/cert.go
@@ -113,10 +113,10 @@ func (cc *CachedCertificate) String() string {
|
|||||||
return cc.Certificate.String()
|
return cc.Certificate.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Recombine will attempt to unmarshal a certificate received in a handshake.
|
// RecombineAndValidate will attempt to unmarshal a certificate received in a handshake.
|
||||||
// Handshakes save space by placing the peers public key in a different part of the packet, we have to
|
// Handshakes save space by placing the peers public key in a different part of the packet, we have to
|
||||||
// reassemble the actual certificate structure with that in mind.
|
// reassemble the actual certificate structure with that in mind.
|
||||||
func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certificate, error) {
|
func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, curve Curve, caPool *CAPool) (*CachedCertificate, error) {
|
||||||
if publicKey == nil {
|
if publicKey == nil {
|
||||||
return nil, ErrNoPeerStaticKey
|
return nil, ErrNoPeerStaticKey
|
||||||
}
|
}
|
||||||
@@ -125,15 +125,29 @@ func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certific
|
|||||||
return nil, ErrNoPayload
|
return nil, ErrNoPayload
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c, err := unmarshalCertificateFromHandshake(v, rawCertBytes, publicKey, curve)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error unmarshaling cert: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cc, err := caPool.VerifyCertificate(time.Now(), c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("certificate validation failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func unmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte, curve Curve) (Certificate, error) {
|
||||||
var c Certificate
|
var c Certificate
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
switch v {
|
switch v {
|
||||||
// Implementations must ensure the result is a valid cert!
|
// Implementations must ensure the result is a valid cert!
|
||||||
case VersionPre1, Version1:
|
case VersionPre1, Version1:
|
||||||
c, err = unmarshalCertificateV1(rawCertBytes, publicKey)
|
c, err = unmarshalCertificateV1(b, publicKey)
|
||||||
case Version2:
|
case Version2:
|
||||||
c, err = unmarshalCertificateV2(rawCertBytes, publicKey, curve)
|
c, err = unmarshalCertificateV2(b, publicKey, curve)
|
||||||
default:
|
default:
|
||||||
//TODO: CERT-V2 make a static var
|
//TODO: CERT-V2 make a static var
|
||||||
return nil, fmt.Errorf("unknown certificate version %d", v)
|
return nil, fmt.Errorf("unknown certificate version %d", v)
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ type detailsV1 struct {
|
|||||||
curve Curve
|
curve Curve
|
||||||
}
|
}
|
||||||
|
|
||||||
type m = map[string]any
|
type m map[string]interface{}
|
||||||
|
|
||||||
func (c *certificateV1) Version() Version {
|
func (c *certificateV1) Version() Version {
|
||||||
return Version1
|
return Version1
|
||||||
|
|||||||
@@ -39,14 +39,14 @@ func TestCertificateV1_Marshal(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
b, err := nc.Marshal()
|
b, err := nc.Marshal()
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
//t.Log("Cert size:", len(b))
|
//t.Log("Cert size:", len(b))
|
||||||
|
|
||||||
nc2, err := unmarshalCertificateV1(b, nil)
|
nc2, err := unmarshalCertificateV1(b, nil)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
assert.Equal(t, Version1, nc.Version())
|
assert.Equal(t, nc.Version(), Version1)
|
||||||
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
assert.Equal(t, nc.Curve(), Curve_CURVE25519)
|
||||||
assert.Equal(t, nc.Signature(), nc2.Signature())
|
assert.Equal(t, nc.Signature(), nc2.Signature())
|
||||||
assert.Equal(t, nc.Name(), nc2.Name())
|
assert.Equal(t, nc.Name(), nc2.Name())
|
||||||
assert.Equal(t, nc.NotBefore(), nc2.NotBefore())
|
assert.Equal(t, nc.NotBefore(), nc2.NotBefore())
|
||||||
@@ -99,8 +99,8 @@ func TestCertificateV1_MarshalJSON(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
b, err := nc.MarshalJSON()
|
b, err := nc.MarshalJSON()
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.JSONEq(
|
assert.Equal(
|
||||||
t,
|
t,
|
||||||
"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}",
|
"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}",
|
||||||
string(b),
|
string(b),
|
||||||
@@ -110,47 +110,47 @@ func TestCertificateV1_MarshalJSON(t *testing.T) {
|
|||||||
func TestCertificateV1_VerifyPrivateKey(t *testing.T) {
|
func TestCertificateV1_VerifyPrivateKey(t *testing.T) {
|
||||||
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
|
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
|
err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
_, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
|
_, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
|
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
|
||||||
require.Error(t, err)
|
assert.NotNil(t, err)
|
||||||
|
|
||||||
c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
|
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
assert.Equal(t, Curve_CURVE25519, curve)
|
assert.Equal(t, Curve_CURVE25519, curve)
|
||||||
err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
|
err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
_, priv2 := X25519Keypair()
|
_, priv2 := X25519Keypair()
|
||||||
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
|
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
|
||||||
require.Error(t, err)
|
assert.NotNil(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) {
|
func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) {
|
||||||
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
|
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
err := ca.VerifyPrivateKey(Curve_P256, caKey)
|
err := ca.VerifyPrivateKey(Curve_P256, caKey)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
_, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
|
_, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
err = ca.VerifyPrivateKey(Curve_P256, caKey2)
|
err = ca.VerifyPrivateKey(Curve_P256, caKey2)
|
||||||
require.Error(t, err)
|
assert.NotNil(t, err)
|
||||||
|
|
||||||
c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
|
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
assert.Equal(t, Curve_P256, curve)
|
assert.Equal(t, Curve_P256, curve)
|
||||||
err = c.VerifyPrivateKey(Curve_P256, rawPriv)
|
err = c.VerifyPrivateKey(Curve_P256, rawPriv)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
_, priv2 := P256Keypair()
|
_, priv2 := P256Keypair()
|
||||||
err = c.VerifyPrivateKey(Curve_P256, priv2)
|
err = c.VerifyPrivateKey(Curve_P256, priv2)
|
||||||
require.Error(t, err)
|
assert.NotNil(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure that upgrading the protobuf library does not change how certificates
|
// Ensure that upgrading the protobuf library does not change how certificates
|
||||||
@@ -182,11 +182,11 @@ func TestMarshalingCertificateV1Consistency(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
b, err := nc.Marshal()
|
b, err := nc.Marshal()
|
||||||
require.NoError(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b))
|
assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b))
|
||||||
|
|
||||||
b, err = proto.Marshal(nc.getRawDetails())
|
b, err = proto.Marshal(nc.getRawDetails())
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
|
assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -201,7 +201,7 @@ func TestUnmarshalCertificateV1(t *testing.T) {
|
|||||||
// Test that we don't panic with an invalid certificate (#332)
|
// Test that we don't panic with an invalid certificate (#332)
|
||||||
data := []byte("\x98\x00\x00")
|
data := []byte("\x98\x00\x00")
|
||||||
_, err := unmarshalCertificateV1(data, nil)
|
_, err := unmarshalCertificateV1(data, nil)
|
||||||
require.EqualError(t, err, "encoded Details was nil")
|
assert.EqualError(t, err, "encoded Details was nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
func appendByteSlices(b ...[]byte) []byte {
|
func appendByteSlices(b ...[]byte) []byte {
|
||||||
|
|||||||
@@ -45,14 +45,14 @@ func TestCertificateV2_Marshal(t *testing.T) {
|
|||||||
nc.rawDetails = db
|
nc.rawDetails = db
|
||||||
|
|
||||||
b, err := nc.Marshal()
|
b, err := nc.Marshal()
|
||||||
require.NoError(t, err)
|
require.Nil(t, err)
|
||||||
//t.Log("Cert size:", len(b))
|
//t.Log("Cert size:", len(b))
|
||||||
|
|
||||||
nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519)
|
nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
assert.Equal(t, Version2, nc.Version())
|
assert.Equal(t, nc.Version(), Version2)
|
||||||
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
assert.Equal(t, nc.Curve(), Curve_CURVE25519)
|
||||||
assert.Equal(t, nc.Signature(), nc2.Signature())
|
assert.Equal(t, nc.Signature(), nc2.Signature())
|
||||||
assert.Equal(t, nc.Name(), nc2.Name())
|
assert.Equal(t, nc.Name(), nc2.Name())
|
||||||
assert.Equal(t, nc.NotBefore(), nc2.NotBefore())
|
assert.Equal(t, nc.NotBefore(), nc2.NotBefore())
|
||||||
@@ -114,15 +114,15 @@ func TestCertificateV2_MarshalJSON(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
b, err := nc.MarshalJSON()
|
b, err := nc.MarshalJSON()
|
||||||
require.ErrorIs(t, err, ErrMissingDetails)
|
assert.ErrorIs(t, err, ErrMissingDetails)
|
||||||
|
|
||||||
rd, err := nc.details.Marshal()
|
rd, err := nc.details.Marshal()
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
nc.rawDetails = rd
|
nc.rawDetails = rd
|
||||||
b, err = nc.MarshalJSON()
|
b, err = nc.MarshalJSON()
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.JSONEq(
|
assert.Equal(
|
||||||
t,
|
t,
|
||||||
"{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}",
|
"{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}",
|
||||||
string(b),
|
string(b),
|
||||||
@@ -132,85 +132,85 @@ func TestCertificateV2_MarshalJSON(t *testing.T) {
|
|||||||
func TestCertificateV2_VerifyPrivateKey(t *testing.T) {
|
func TestCertificateV2_VerifyPrivateKey(t *testing.T) {
|
||||||
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
|
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
|
err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16])
|
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16])
|
||||||
require.ErrorIs(t, err, ErrInvalidPrivateKey)
|
assert.ErrorIs(t, err, ErrInvalidPrivateKey)
|
||||||
|
|
||||||
_, caKey2, err := ed25519.GenerateKey(rand.Reader)
|
_, caKey2, err := ed25519.GenerateKey(rand.Reader)
|
||||||
require.NoError(t, err)
|
require.Nil(t, err)
|
||||||
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
|
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
|
||||||
require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
|
assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
|
||||||
|
|
||||||
c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
|
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
assert.Equal(t, Curve_CURVE25519, curve)
|
assert.Equal(t, Curve_CURVE25519, curve)
|
||||||
err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
|
err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
_, priv2 := X25519Keypair()
|
_, priv2 := X25519Keypair()
|
||||||
err = c.VerifyPrivateKey(Curve_P256, priv2)
|
err = c.VerifyPrivateKey(Curve_P256, priv2)
|
||||||
require.ErrorIs(t, err, ErrPublicPrivateCurveMismatch)
|
assert.ErrorIs(t, err, ErrPublicPrivateCurveMismatch)
|
||||||
|
|
||||||
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
|
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
|
||||||
require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
|
assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
|
||||||
|
|
||||||
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2[:16])
|
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2[:16])
|
||||||
require.ErrorIs(t, err, ErrInvalidPrivateKey)
|
assert.ErrorIs(t, err, ErrInvalidPrivateKey)
|
||||||
|
|
||||||
ac, ok := c.(*certificateV2)
|
ac, ok := c.(*certificateV2)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
ac.curve = Curve(99)
|
ac.curve = Curve(99)
|
||||||
err = c.VerifyPrivateKey(Curve(99), priv2)
|
err = c.VerifyPrivateKey(Curve(99), priv2)
|
||||||
require.EqualError(t, err, "invalid curve: 99")
|
assert.EqualError(t, err, "invalid curve: 99")
|
||||||
|
|
||||||
ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
|
ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
|
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16])
|
err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16])
|
||||||
require.ErrorIs(t, err, ErrInvalidPrivateKey)
|
assert.ErrorIs(t, err, ErrInvalidPrivateKey)
|
||||||
|
|
||||||
c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv)
|
rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv)
|
||||||
|
|
||||||
err = c.VerifyPrivateKey(Curve_P256, priv[:16])
|
err = c.VerifyPrivateKey(Curve_P256, priv[:16])
|
||||||
require.ErrorIs(t, err, ErrInvalidPrivateKey)
|
assert.ErrorIs(t, err, ErrInvalidPrivateKey)
|
||||||
|
|
||||||
err = c.VerifyPrivateKey(Curve_P256, priv)
|
err = c.VerifyPrivateKey(Curve_P256, priv)
|
||||||
require.ErrorIs(t, err, ErrInvalidPrivateKey)
|
assert.ErrorIs(t, err, ErrInvalidPrivateKey)
|
||||||
|
|
||||||
aCa, ok := ca2.(*certificateV2)
|
aCa, ok := ca2.(*certificateV2)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
aCa.curve = Curve(99)
|
aCa.curve = Curve(99)
|
||||||
err = aCa.VerifyPrivateKey(Curve(99), priv2)
|
err = aCa.VerifyPrivateKey(Curve(99), priv2)
|
||||||
require.EqualError(t, err, "invalid curve: 99")
|
assert.EqualError(t, err, "invalid curve: 99")
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) {
|
func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) {
|
||||||
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
|
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
err := ca.VerifyPrivateKey(Curve_P256, caKey)
|
err := ca.VerifyPrivateKey(Curve_P256, caKey)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
_, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
|
_, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
err = ca.VerifyPrivateKey(Curve_P256, caKey2)
|
err = ca.VerifyPrivateKey(Curve_P256, caKey2)
|
||||||
require.Error(t, err)
|
assert.NotNil(t, err)
|
||||||
|
|
||||||
c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
|
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
assert.Equal(t, Curve_P256, curve)
|
assert.Equal(t, Curve_P256, curve)
|
||||||
err = c.VerifyPrivateKey(Curve_P256, rawPriv)
|
err = c.VerifyPrivateKey(Curve_P256, rawPriv)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
_, priv2 := P256Keypair()
|
_, priv2 := P256Keypair()
|
||||||
err = c.VerifyPrivateKey(Curve_P256, priv2)
|
err = c.VerifyPrivateKey(Curve_P256, priv2)
|
||||||
require.Error(t, err)
|
assert.NotNil(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV2_Copy(t *testing.T) {
|
func TestCertificateV2_Copy(t *testing.T) {
|
||||||
@@ -223,7 +223,7 @@ func TestCertificateV2_Copy(t *testing.T) {
|
|||||||
func TestUnmarshalCertificateV2(t *testing.T) {
|
func TestUnmarshalCertificateV2(t *testing.T) {
|
||||||
data := []byte("\x98\x00\x00")
|
data := []byte("\x98\x00\x00")
|
||||||
_, err := unmarshalCertificateV2(data, nil, Curve_CURVE25519)
|
_, err := unmarshalCertificateV2(data, nil, Curve_CURVE25519)
|
||||||
require.EqualError(t, err, "bad wire format")
|
assert.EqualError(t, err, "bad wire format")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV2_marshalForSigningStability(t *testing.T) {
|
func TestCertificateV2_marshalForSigningStability(t *testing.T) {
|
||||||
|
|||||||
@@ -4,20 +4,19 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"golang.org/x/crypto/argon2"
|
"golang.org/x/crypto/argon2"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewArgon2Parameters(t *testing.T) {
|
func TestNewArgon2Parameters(t *testing.T) {
|
||||||
p := NewArgon2Parameters(64*1024, 4, 3)
|
p := NewArgon2Parameters(64*1024, 4, 3)
|
||||||
assert.Equal(t, &Argon2Parameters{
|
assert.EqualValues(t, &Argon2Parameters{
|
||||||
version: argon2.Version,
|
version: argon2.Version,
|
||||||
Memory: 64 * 1024,
|
Memory: 64 * 1024,
|
||||||
Parallelism: 4,
|
Parallelism: 4,
|
||||||
Iterations: 3,
|
Iterations: 3,
|
||||||
}, p)
|
}, p)
|
||||||
p = NewArgon2Parameters(2*1024*1024, 2, 1)
|
p = NewArgon2Parameters(2*1024*1024, 2, 1)
|
||||||
assert.Equal(t, &Argon2Parameters{
|
assert.EqualValues(t, &Argon2Parameters{
|
||||||
version: argon2.Version,
|
version: argon2.Version,
|
||||||
Memory: 2 * 1024 * 1024,
|
Memory: 2 * 1024 * 1024,
|
||||||
Parallelism: 2,
|
Parallelism: 2,
|
||||||
@@ -62,35 +61,35 @@ qrlJ69wer3ZUHFXA
|
|||||||
|
|
||||||
// Success test case
|
// Success test case
|
||||||
curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle)
|
curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, Curve_CURVE25519, curve)
|
assert.Equal(t, Curve_CURVE25519, curve)
|
||||||
assert.Len(t, k, 64)
|
assert.Len(t, k, 64)
|
||||||
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
|
||||||
|
|
||||||
// Fail due to short key
|
// Fail due to short key
|
||||||
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
|
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
|
||||||
require.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key")
|
assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key")
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
|
||||||
|
|
||||||
// Fail due to invalid banner
|
// Fail due to invalid banner
|
||||||
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
|
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
|
||||||
require.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
|
assert.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, rest, invalidPem)
|
assert.Equal(t, rest, invalidPem)
|
||||||
|
|
||||||
// Fail due to ivalid PEM format, because
|
// Fail due to ivalid PEM format, because
|
||||||
// it's missing the requisite pre-encapsulation boundary.
|
// it's missing the requisite pre-encapsulation boundary.
|
||||||
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
|
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
|
||||||
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, rest, invalidPem)
|
assert.Equal(t, rest, invalidPem)
|
||||||
|
|
||||||
// Fail due to invalid passphrase
|
// Fail due to invalid passphrase
|
||||||
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey)
|
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey)
|
||||||
require.EqualError(t, err, "invalid passphrase or corrupt private key")
|
assert.EqualError(t, err, "invalid passphrase or corrupt private key")
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, []byte{}, rest)
|
assert.Equal(t, rest, []byte{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) {
|
func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) {
|
||||||
@@ -100,14 +99,14 @@ func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) {
|
|||||||
bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
|
bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
|
||||||
kdfParams := NewArgon2Parameters(64*1024, 4, 3)
|
kdfParams := NewArgon2Parameters(64*1024, 4, 3)
|
||||||
key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams)
|
key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
// Verify the "key" can be decrypted successfully
|
// Verify the "key" can be decrypted successfully
|
||||||
curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key)
|
curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key)
|
||||||
assert.Len(t, k, 64)
|
assert.Len(t, k, 64)
|
||||||
assert.Equal(t, Curve_CURVE25519, curve)
|
assert.Equal(t, Curve_CURVE25519, curve)
|
||||||
assert.Equal(t, []byte{}, rest)
|
assert.Equal(t, rest, []byte{})
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
// EncryptAndMarshalEd25519PrivateKey does not create any errors itself
|
// EncryptAndMarshalEd25519PrivateKey does not create any errors itself
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestUnmarshalCertificateFromPEM(t *testing.T) {
|
func TestUnmarshalCertificateFromPEM(t *testing.T) {
|
||||||
@@ -36,20 +35,20 @@ bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
|
|||||||
cert, rest, err := UnmarshalCertificateFromPEM(certBundle)
|
cert, rest, err := UnmarshalCertificateFromPEM(certBundle)
|
||||||
assert.NotNil(t, cert)
|
assert.NotNil(t, cert)
|
||||||
assert.Equal(t, rest, append(badBanner, invalidPem...))
|
assert.Equal(t, rest, append(badBanner, invalidPem...))
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
// Fail due to invalid banner.
|
// Fail due to invalid banner.
|
||||||
cert, rest, err = UnmarshalCertificateFromPEM(rest)
|
cert, rest, err = UnmarshalCertificateFromPEM(rest)
|
||||||
assert.Nil(t, cert)
|
assert.Nil(t, cert)
|
||||||
assert.Equal(t, rest, invalidPem)
|
assert.Equal(t, rest, invalidPem)
|
||||||
require.EqualError(t, err, "bytes did not contain a proper certificate banner")
|
assert.EqualError(t, err, "bytes did not contain a proper certificate banner")
|
||||||
|
|
||||||
// Fail due to ivalid PEM format, because
|
// Fail due to ivalid PEM format, because
|
||||||
// it's missing the requisite pre-encapsulation boundary.
|
// it's missing the requisite pre-encapsulation boundary.
|
||||||
cert, rest, err = UnmarshalCertificateFromPEM(rest)
|
cert, rest, err = UnmarshalCertificateFromPEM(rest)
|
||||||
assert.Nil(t, cert)
|
assert.Nil(t, cert)
|
||||||
assert.Equal(t, rest, invalidPem)
|
assert.Equal(t, rest, invalidPem)
|
||||||
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalSigningPrivateKeyFromPEM(t *testing.T) {
|
func TestUnmarshalSigningPrivateKeyFromPEM(t *testing.T) {
|
||||||
@@ -85,33 +84,33 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
|||||||
assert.Len(t, k, 64)
|
assert.Len(t, k, 64)
|
||||||
assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
|
||||||
assert.Equal(t, Curve_CURVE25519, curve)
|
assert.Equal(t, Curve_CURVE25519, curve)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
// Success test case
|
// Success test case
|
||||||
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
|
||||||
assert.Len(t, k, 32)
|
assert.Len(t, k, 32)
|
||||||
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
|
||||||
assert.Equal(t, Curve_P256, curve)
|
assert.Equal(t, Curve_P256, curve)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
// Fail due to short key
|
// Fail due to short key
|
||||||
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
|
||||||
require.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key")
|
assert.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key")
|
||||||
|
|
||||||
// Fail due to invalid banner
|
// Fail due to invalid banner
|
||||||
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, rest, invalidPem)
|
assert.Equal(t, rest, invalidPem)
|
||||||
require.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner")
|
assert.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner")
|
||||||
|
|
||||||
// Fail due to ivalid PEM format, because
|
// Fail due to ivalid PEM format, because
|
||||||
// it's missing the requisite pre-encapsulation boundary.
|
// it's missing the requisite pre-encapsulation boundary.
|
||||||
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, rest, invalidPem)
|
assert.Equal(t, rest, invalidPem)
|
||||||
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalPrivateKeyFromPEM(t *testing.T) {
|
func TestUnmarshalPrivateKeyFromPEM(t *testing.T) {
|
||||||
@@ -147,33 +146,33 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
|||||||
assert.Len(t, k, 32)
|
assert.Len(t, k, 32)
|
||||||
assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
|
||||||
assert.Equal(t, Curve_CURVE25519, curve)
|
assert.Equal(t, Curve_CURVE25519, curve)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
// Success test case
|
// Success test case
|
||||||
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
|
||||||
assert.Len(t, k, 32)
|
assert.Len(t, k, 32)
|
||||||
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
|
||||||
assert.Equal(t, Curve_P256, curve)
|
assert.Equal(t, Curve_P256, curve)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
// Fail due to short key
|
// Fail due to short key
|
||||||
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
|
||||||
require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key")
|
assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key")
|
||||||
|
|
||||||
// Fail due to invalid banner
|
// Fail due to invalid banner
|
||||||
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, rest, invalidPem)
|
assert.Equal(t, rest, invalidPem)
|
||||||
require.EqualError(t, err, "bytes did not contain a proper private key banner")
|
assert.EqualError(t, err, "bytes did not contain a proper private key banner")
|
||||||
|
|
||||||
// Fail due to ivalid PEM format, because
|
// Fail due to ivalid PEM format, because
|
||||||
// it's missing the requisite pre-encapsulation boundary.
|
// it's missing the requisite pre-encapsulation boundary.
|
||||||
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, rest, invalidPem)
|
assert.Equal(t, rest, invalidPem)
|
||||||
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
|
func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
|
||||||
@@ -201,9 +200,9 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
|||||||
|
|
||||||
// Success test case
|
// Success test case
|
||||||
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
|
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
|
||||||
assert.Len(t, k, 32)
|
assert.Equal(t, 32, len(k))
|
||||||
assert.Equal(t, Curve_CURVE25519, curve)
|
assert.Equal(t, Curve_CURVE25519, curve)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
|
||||||
|
|
||||||
// Fail due to short key
|
// Fail due to short key
|
||||||
@@ -211,13 +210,13 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
|||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, Curve_CURVE25519, curve)
|
assert.Equal(t, Curve_CURVE25519, curve)
|
||||||
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
|
||||||
require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
|
assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
|
||||||
|
|
||||||
// Fail due to invalid banner
|
// Fail due to invalid banner
|
||||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, Curve_CURVE25519, curve)
|
assert.Equal(t, Curve_CURVE25519, curve)
|
||||||
require.EqualError(t, err, "bytes did not contain a proper public key banner")
|
assert.EqualError(t, err, "bytes did not contain a proper public key banner")
|
||||||
assert.Equal(t, rest, invalidPem)
|
assert.Equal(t, rest, invalidPem)
|
||||||
|
|
||||||
// Fail due to ivalid PEM format, because
|
// Fail due to ivalid PEM format, because
|
||||||
@@ -226,7 +225,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
|||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, Curve_CURVE25519, curve)
|
assert.Equal(t, Curve_CURVE25519, curve)
|
||||||
assert.Equal(t, rest, invalidPem)
|
assert.Equal(t, rest, invalidPem)
|
||||||
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalX25519PublicKey(t *testing.T) {
|
func TestUnmarshalX25519PublicKey(t *testing.T) {
|
||||||
@@ -260,15 +259,15 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
|||||||
|
|
||||||
// Success test case
|
// Success test case
|
||||||
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
|
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
|
||||||
assert.Len(t, k, 32)
|
assert.Equal(t, 32, len(k))
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
|
||||||
assert.Equal(t, Curve_CURVE25519, curve)
|
assert.Equal(t, Curve_CURVE25519, curve)
|
||||||
|
|
||||||
// Success test case
|
// Success test case
|
||||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||||
assert.Len(t, k, 65)
|
assert.Equal(t, 65, len(k))
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
|
||||||
assert.Equal(t, Curve_P256, curve)
|
assert.Equal(t, Curve_P256, curve)
|
||||||
|
|
||||||
@@ -276,12 +275,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
|||||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
|
||||||
require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
|
assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
|
||||||
|
|
||||||
// Fail due to invalid banner
|
// Fail due to invalid banner
|
||||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
require.EqualError(t, err, "bytes did not contain a proper public key banner")
|
assert.EqualError(t, err, "bytes did not contain a proper public key banner")
|
||||||
assert.Equal(t, rest, invalidPem)
|
assert.Equal(t, rest, invalidPem)
|
||||||
|
|
||||||
// Fail due to ivalid PEM format, because
|
// Fail due to ivalid PEM format, because
|
||||||
@@ -289,5 +288,5 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
|||||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, rest, invalidPem)
|
assert.Equal(t, rest, invalidPem)
|
||||||
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCertificateV1_Sign(t *testing.T) {
|
func TestCertificateV1_Sign(t *testing.T) {
|
||||||
@@ -38,14 +37,14 @@ func TestCertificateV1_Sign(t *testing.T) {
|
|||||||
|
|
||||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv)
|
c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.NotNil(t, c)
|
assert.NotNil(t, c)
|
||||||
assert.True(t, c.CheckSignature(pub))
|
assert.True(t, c.CheckSignature(pub))
|
||||||
|
|
||||||
b, err := c.Marshal()
|
b, err := c.Marshal()
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
uc, err := unmarshalCertificateV1(b, nil)
|
uc, err := unmarshalCertificateV1(b, nil)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.NotNil(t, uc)
|
assert.NotNil(t, uc)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -74,18 +73,18 @@ func TestCertificateV1_SignP256(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
|
pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
|
||||||
rawPriv := priv.D.FillBytes(make([]byte, 32))
|
rawPriv := priv.D.FillBytes(make([]byte, 32))
|
||||||
|
|
||||||
c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv)
|
c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.NotNil(t, c)
|
assert.NotNil(t, c)
|
||||||
assert.True(t, c.CheckSignature(pub))
|
assert.True(t, c.CheckSignature(pub))
|
||||||
|
|
||||||
b, err := c.Marshal()
|
b, err := c.Marshal()
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
uc, err := unmarshalCertificateV1(b, nil)
|
uc, err := unmarshalCertificateV1(b, nil)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.NotNil(t, uc)
|
assert.NotNil(t, uc)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_caSummary(t *testing.T) {
|
func Test_caSummary(t *testing.T) {
|
||||||
@@ -90,75 +89,75 @@ func Test_ca(t *testing.T) {
|
|||||||
assertHelpError(t, ca(
|
assertHelpError(t, ca(
|
||||||
[]string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
|
[]string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
|
||||||
), "-name is required")
|
), "-name is required")
|
||||||
assert.Empty(t, ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
|
|
||||||
// ipv4 only ips
|
// ipv4 only ips
|
||||||
assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100")
|
assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100")
|
||||||
assert.Empty(t, ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
|
|
||||||
// ipv4 only subnets
|
// ipv4 only subnets
|
||||||
assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100")
|
assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100")
|
||||||
assert.Empty(t, ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
|
|
||||||
// failed key write
|
// failed key write
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
|
args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
|
||||||
require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
|
assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
|
||||||
assert.Empty(t, ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
|
|
||||||
// create temp key file
|
// create temp key file
|
||||||
keyF, err := os.CreateTemp("", "test.key")
|
keyF, err := os.CreateTemp("", "test.key")
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
require.NoError(t, os.Remove(keyF.Name()))
|
assert.Nil(t, os.Remove(keyF.Name()))
|
||||||
|
|
||||||
// failed cert write
|
// failed cert write
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
|
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
|
||||||
require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
|
assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
|
||||||
assert.Empty(t, ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
|
|
||||||
// create temp cert file
|
// create temp cert file
|
||||||
crtF, err := os.CreateTemp("", "test.crt")
|
crtF, err := os.CreateTemp("", "test.crt")
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
require.NoError(t, os.Remove(crtF.Name()))
|
assert.Nil(t, os.Remove(crtF.Name()))
|
||||||
require.NoError(t, os.Remove(keyF.Name()))
|
assert.Nil(t, os.Remove(keyF.Name()))
|
||||||
|
|
||||||
// test proper cert with removed empty groups and subnets
|
// test proper cert with removed empty groups and subnets
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||||
require.NoError(t, ca(args, ob, eb, nopw))
|
assert.Nil(t, ca(args, ob, eb, nopw))
|
||||||
assert.Empty(t, ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
|
|
||||||
// read cert and key files
|
// read cert and key files
|
||||||
rb, _ := os.ReadFile(keyF.Name())
|
rb, _ := os.ReadFile(keyF.Name())
|
||||||
lKey, b, c, err := cert.UnmarshalSigningPrivateKeyFromPEM(rb)
|
lKey, b, c, err := cert.UnmarshalSigningPrivateKeyFromPEM(rb)
|
||||||
assert.Equal(t, cert.Curve_CURVE25519, c)
|
assert.Equal(t, cert.Curve_CURVE25519, c)
|
||||||
assert.Empty(t, b)
|
assert.Len(t, b, 0)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Len(t, lKey, 64)
|
assert.Len(t, lKey, 64)
|
||||||
|
|
||||||
rb, _ = os.ReadFile(crtF.Name())
|
rb, _ = os.ReadFile(crtF.Name())
|
||||||
lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb)
|
lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb)
|
||||||
assert.Empty(t, b)
|
assert.Len(t, b, 0)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
assert.Equal(t, "test", lCrt.Name())
|
assert.Equal(t, "test", lCrt.Name())
|
||||||
assert.Empty(t, lCrt.Networks())
|
assert.Len(t, lCrt.Networks(), 0)
|
||||||
assert.True(t, lCrt.IsCA())
|
assert.True(t, lCrt.IsCA())
|
||||||
assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups())
|
assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups())
|
||||||
assert.Empty(t, lCrt.UnsafeNetworks())
|
assert.Len(t, lCrt.UnsafeNetworks(), 0)
|
||||||
assert.Len(t, lCrt.PublicKey(), 32)
|
assert.Len(t, lCrt.PublicKey(), 32)
|
||||||
assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore()))
|
assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore()))
|
||||||
assert.Empty(t, lCrt.Issuer())
|
assert.Equal(t, "", lCrt.Issuer())
|
||||||
assert.True(t, lCrt.CheckSignature(lCrt.PublicKey()))
|
assert.True(t, lCrt.CheckSignature(lCrt.PublicKey()))
|
||||||
|
|
||||||
// test encrypted key
|
// test encrypted key
|
||||||
@@ -167,15 +166,15 @@ func Test_ca(t *testing.T) {
|
|||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||||
require.NoError(t, ca(args, ob, eb, testpw))
|
assert.Nil(t, ca(args, ob, eb, testpw))
|
||||||
assert.Equal(t, pwPromptOb, ob.String())
|
assert.Equal(t, pwPromptOb, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
|
|
||||||
// read encrypted key file and verify default params
|
// read encrypted key file and verify default params
|
||||||
rb, _ = os.ReadFile(keyF.Name())
|
rb, _ = os.ReadFile(keyF.Name())
|
||||||
k, _ := pem.Decode(rb)
|
k, _ := pem.Decode(rb)
|
||||||
ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes)
|
ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
// we won't know salt in advance, so just check start of string
|
// we won't know salt in advance, so just check start of string
|
||||||
assert.Equal(t, uint32(2*1024*1024), ned.EncryptionMetadata.Argon2Parameters.Memory)
|
assert.Equal(t, uint32(2*1024*1024), ned.EncryptionMetadata.Argon2Parameters.Memory)
|
||||||
assert.Equal(t, uint8(4), ned.EncryptionMetadata.Argon2Parameters.Parallelism)
|
assert.Equal(t, uint8(4), ned.EncryptionMetadata.Argon2Parameters.Parallelism)
|
||||||
@@ -185,8 +184,8 @@ func Test_ca(t *testing.T) {
|
|||||||
var curve cert.Curve
|
var curve cert.Curve
|
||||||
curve, lKey, b, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rb)
|
curve, lKey, b, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rb)
|
||||||
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Len(t, b, 0)
|
||||||
assert.Len(t, lKey, 64)
|
assert.Len(t, lKey, 64)
|
||||||
|
|
||||||
// test when reading passsword results in an error
|
// test when reading passsword results in an error
|
||||||
@@ -195,9 +194,9 @@ func Test_ca(t *testing.T) {
|
|||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||||
require.Error(t, ca(args, ob, eb, errpw))
|
assert.Error(t, ca(args, ob, eb, errpw))
|
||||||
assert.Equal(t, pwPromptOb, ob.String())
|
assert.Equal(t, pwPromptOb, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
|
|
||||||
// test when user fails to enter a password
|
// test when user fails to enter a password
|
||||||
os.Remove(keyF.Name())
|
os.Remove(keyF.Name())
|
||||||
@@ -205,9 +204,9 @@ func Test_ca(t *testing.T) {
|
|||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||||
require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
|
assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
|
||||||
assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
|
assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
|
|
||||||
// create valid cert/key for overwrite tests
|
// create valid cert/key for overwrite tests
|
||||||
os.Remove(keyF.Name())
|
os.Remove(keyF.Name())
|
||||||
@@ -215,24 +214,24 @@ func Test_ca(t *testing.T) {
|
|||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||||
require.NoError(t, ca(args, ob, eb, nopw))
|
assert.Nil(t, ca(args, ob, eb, nopw))
|
||||||
|
|
||||||
// test that we won't overwrite existing certificate file
|
// test that we won't overwrite existing certificate file
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||||
require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
|
assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
|
||||||
assert.Empty(t, ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
|
|
||||||
// test that we won't overwrite existing key file
|
// test that we won't overwrite existing key file
|
||||||
os.Remove(keyF.Name())
|
os.Remove(keyF.Name())
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||||
require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
|
assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
|
||||||
assert.Empty(t, ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
os.Remove(keyF.Name())
|
os.Remove(keyF.Name())
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_keygenSummary(t *testing.T) {
|
func Test_keygenSummary(t *testing.T) {
|
||||||
@@ -37,59 +36,59 @@ func Test_keygen(t *testing.T) {
|
|||||||
|
|
||||||
// required args
|
// required args
|
||||||
assertHelpError(t, keygen([]string{"-out-pub", "nope"}, ob, eb), "-out-key is required")
|
assertHelpError(t, keygen([]string{"-out-pub", "nope"}, ob, eb), "-out-key is required")
|
||||||
assert.Empty(t, ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
|
|
||||||
assertHelpError(t, keygen([]string{"-out-key", "nope"}, ob, eb), "-out-pub is required")
|
assertHelpError(t, keygen([]string{"-out-key", "nope"}, ob, eb), "-out-pub is required")
|
||||||
assert.Empty(t, ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
|
|
||||||
// failed key write
|
// failed key write
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args := []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", "/do/not/write/pleasekey"}
|
args := []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", "/do/not/write/pleasekey"}
|
||||||
require.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
|
assert.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
|
||||||
assert.Empty(t, ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
|
|
||||||
// create temp key file
|
// create temp key file
|
||||||
keyF, err := os.CreateTemp("", "test.key")
|
keyF, err := os.CreateTemp("", "test.key")
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
defer os.Remove(keyF.Name())
|
defer os.Remove(keyF.Name())
|
||||||
|
|
||||||
// failed pub write
|
// failed pub write
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", keyF.Name()}
|
args = []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", keyF.Name()}
|
||||||
require.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError)
|
assert.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError)
|
||||||
assert.Empty(t, ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
|
|
||||||
// create temp pub file
|
// create temp pub file
|
||||||
pubF, err := os.CreateTemp("", "test.pub")
|
pubF, err := os.CreateTemp("", "test.pub")
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
defer os.Remove(pubF.Name())
|
defer os.Remove(pubF.Name())
|
||||||
|
|
||||||
// test proper keygen
|
// test proper keygen
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()}
|
args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()}
|
||||||
require.NoError(t, keygen(args, ob, eb))
|
assert.Nil(t, keygen(args, ob, eb))
|
||||||
assert.Empty(t, ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
|
|
||||||
// read cert and key files
|
// read cert and key files
|
||||||
rb, _ := os.ReadFile(keyF.Name())
|
rb, _ := os.ReadFile(keyF.Name())
|
||||||
lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb)
|
lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb)
|
||||||
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
||||||
assert.Empty(t, b)
|
assert.Len(t, b, 0)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Len(t, lKey, 32)
|
assert.Len(t, lKey, 32)
|
||||||
|
|
||||||
rb, _ = os.ReadFile(pubF.Name())
|
rb, _ = os.ReadFile(pubF.Name())
|
||||||
lPub, b, curve, err := cert.UnmarshalPublicKeyFromPEM(rb)
|
lPub, b, curve, err := cert.UnmarshalPublicKeyFromPEM(rb)
|
||||||
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
||||||
assert.Empty(t, b)
|
assert.Len(t, b, 0)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Len(t, lPub, 32)
|
assert.Len(t, lPub, 32)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ func (he *helpError) Error() string {
|
|||||||
return he.s
|
return he.s
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHelpErrorf(s string, v ...any) error {
|
func newHelpErrorf(s string, v ...interface{}) error {
|
||||||
return &helpError{s: fmt.Sprintf(s, v...)}
|
return &helpError{s: fmt.Sprintf(s, v...)}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_help(t *testing.T) {
|
func Test_help(t *testing.T) {
|
||||||
@@ -80,7 +79,7 @@ func assertHelpError(t *testing.T, err error, msg string) {
|
|||||||
t.Fatal(fmt.Sprintf("err was not a helpError: %q, expected %q", err, msg))
|
t.Fatal(fmt.Sprintf("err was not a helpError: %q, expected %q", err, msg))
|
||||||
}
|
}
|
||||||
|
|
||||||
require.EqualError(t, err, msg)
|
assert.EqualError(t, err, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func optionalPkcs11String(msg string) string {
|
func optionalPkcs11String(msg string) string {
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_printSummary(t *testing.T) {
|
func Test_printSummary(t *testing.T) {
|
||||||
@@ -43,30 +42,30 @@ func Test_printCert(t *testing.T) {
|
|||||||
|
|
||||||
// no path
|
// no path
|
||||||
err := printCert([]string{}, ob, eb)
|
err := printCert([]string{}, ob, eb)
|
||||||
assert.Empty(t, ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
assertHelpError(t, err, "-path is required")
|
assertHelpError(t, err, "-path is required")
|
||||||
|
|
||||||
// no cert at path
|
// no cert at path
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
err = printCert([]string{"-path", "does_not_exist"}, ob, eb)
|
err = printCert([]string{"-path", "does_not_exist"}, ob, eb)
|
||||||
assert.Empty(t, ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
require.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError)
|
assert.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError)
|
||||||
|
|
||||||
// invalid cert at path
|
// invalid cert at path
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
tf, err := os.CreateTemp("", "print-cert")
|
tf, err := os.CreateTemp("", "print-cert")
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
defer os.Remove(tf.Name())
|
defer os.Remove(tf.Name())
|
||||||
|
|
||||||
tf.WriteString("-----BEGIN NOPE-----")
|
tf.WriteString("-----BEGIN NOPE-----")
|
||||||
err = printCert([]string{"-path", tf.Name()}, ob, eb)
|
err = printCert([]string{"-path", tf.Name()}, ob, eb)
|
||||||
assert.Empty(t, ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
require.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block")
|
assert.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block")
|
||||||
|
|
||||||
// test multiple certs
|
// test multiple certs
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
@@ -85,7 +84,7 @@ func Test_printCert(t *testing.T) {
|
|||||||
fp, _ := c.Fingerprint()
|
fp, _ := c.Fingerprint()
|
||||||
pk := hex.EncodeToString(c.PublicKey())
|
pk := hex.EncodeToString(c.PublicKey())
|
||||||
sig := hex.EncodeToString(c.Signature())
|
sig := hex.EncodeToString(c.Signature())
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(
|
assert.Equal(
|
||||||
t,
|
t,
|
||||||
//"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n",
|
//"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n",
|
||||||
@@ -155,7 +154,7 @@ func Test_printCert(t *testing.T) {
|
|||||||
`,
|
`,
|
||||||
ob.String(),
|
ob.String(),
|
||||||
)
|
)
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
|
|
||||||
// test json
|
// test json
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
@@ -170,14 +169,14 @@ func Test_printCert(t *testing.T) {
|
|||||||
fp, _ = c.Fingerprint()
|
fp, _ = c.Fingerprint()
|
||||||
pk = hex.EncodeToString(c.PublicKey())
|
pk = hex.EncodeToString(c.PublicKey())
|
||||||
sig = hex.EncodeToString(c.Signature())
|
sig = hex.EncodeToString(c.Signature())
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(
|
assert.Equal(
|
||||||
t,
|
t,
|
||||||
`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
|
`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
|
||||||
`,
|
`,
|
||||||
ob.String(),
|
ob.String(),
|
||||||
)
|
)
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTestCaCert will generate a CA cert
|
// NewTestCaCert will generate a CA cert
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -104,17 +103,17 @@ func Test_signCert(t *testing.T) {
|
|||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||||
require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
|
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
|
||||||
|
|
||||||
// failed to unmarshal key
|
// failed to unmarshal key
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
caKeyF, err := os.CreateTemp("", "sign-cert.key")
|
caKeyF, err := os.CreateTemp("", "sign-cert.key")
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
defer os.Remove(caKeyF.Name())
|
defer os.Remove(caKeyF.Name())
|
||||||
|
|
||||||
args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||||
require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
|
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
@@ -126,7 +125,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
|
|
||||||
// failed to read cert
|
// failed to read cert
|
||||||
args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||||
require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
|
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
@@ -134,11 +133,11 @@ func Test_signCert(t *testing.T) {
|
|||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
caCrtF, err := os.CreateTemp("", "sign-cert.crt")
|
caCrtF, err := os.CreateTemp("", "sign-cert.crt")
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
defer os.Remove(caCrtF.Name())
|
defer os.Remove(caCrtF.Name())
|
||||||
|
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||||
require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
|
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
@@ -149,7 +148,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
|
|
||||||
// failed to read pub
|
// failed to read pub
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
|
||||||
require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
|
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
@@ -157,11 +156,11 @@ func Test_signCert(t *testing.T) {
|
|||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
inPubF, err := os.CreateTemp("", "in.pub")
|
inPubF, err := os.CreateTemp("", "in.pub")
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
defer os.Remove(inPubF.Name())
|
defer os.Remove(inPubF.Name())
|
||||||
|
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
|
||||||
require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
|
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
@@ -211,14 +210,14 @@ func Test_signCert(t *testing.T) {
|
|||||||
// mismatched ca key
|
// mismatched ca key
|
||||||
_, caPriv2, _ := ed25519.GenerateKey(rand.Reader)
|
_, caPriv2, _ := ed25519.GenerateKey(rand.Reader)
|
||||||
caKeyF2, err := os.CreateTemp("", "sign-cert-2.key")
|
caKeyF2, err := os.CreateTemp("", "sign-cert-2.key")
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
defer os.Remove(caKeyF2.Name())
|
defer os.Remove(caKeyF2.Name())
|
||||||
caKeyF2.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv2))
|
caKeyF2.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv2))
|
||||||
|
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
|
||||||
require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
|
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
@@ -226,34 +225,34 @@ func Test_signCert(t *testing.T) {
|
|||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
|
||||||
require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
|
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
// create temp key file
|
// create temp key file
|
||||||
keyF, err := os.CreateTemp("", "test.key")
|
keyF, err := os.CreateTemp("", "test.key")
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
os.Remove(keyF.Name())
|
os.Remove(keyF.Name())
|
||||||
|
|
||||||
// failed cert write
|
// failed cert write
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
|
||||||
require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
|
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
os.Remove(keyF.Name())
|
os.Remove(keyF.Name())
|
||||||
|
|
||||||
// create temp cert file
|
// create temp cert file
|
||||||
crtF, err := os.CreateTemp("", "test.crt")
|
crtF, err := os.CreateTemp("", "test.crt")
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
os.Remove(crtF.Name())
|
os.Remove(crtF.Name())
|
||||||
|
|
||||||
// test proper cert with removed empty groups and subnets
|
// test proper cert with removed empty groups and subnets
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
require.NoError(t, signCert(args, ob, eb, nopw))
|
assert.Nil(t, signCert(args, ob, eb, nopw))
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
@@ -261,14 +260,14 @@ func Test_signCert(t *testing.T) {
|
|||||||
rb, _ := os.ReadFile(keyF.Name())
|
rb, _ := os.ReadFile(keyF.Name())
|
||||||
lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb)
|
lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb)
|
||||||
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
||||||
assert.Empty(t, b)
|
assert.Len(t, b, 0)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Len(t, lKey, 32)
|
assert.Len(t, lKey, 32)
|
||||||
|
|
||||||
rb, _ = os.ReadFile(crtF.Name())
|
rb, _ = os.ReadFile(crtF.Name())
|
||||||
lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb)
|
lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb)
|
||||||
assert.Empty(t, b)
|
assert.Len(t, b, 0)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
assert.Equal(t, "test", lCrt.Name())
|
assert.Equal(t, "test", lCrt.Name())
|
||||||
assert.Equal(t, "1.1.1.1/24", lCrt.Networks()[0].String())
|
assert.Equal(t, "1.1.1.1/24", lCrt.Networks()[0].String())
|
||||||
@@ -296,15 +295,15 @@ func Test_signCert(t *testing.T) {
|
|||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
|
||||||
require.NoError(t, signCert(args, ob, eb, nopw))
|
assert.Nil(t, signCert(args, ob, eb, nopw))
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
// read cert file and check pub key matches in-pub
|
// read cert file and check pub key matches in-pub
|
||||||
rb, _ = os.ReadFile(crtF.Name())
|
rb, _ = os.ReadFile(crtF.Name())
|
||||||
lCrt, b, err = cert.UnmarshalCertificateFromPEM(rb)
|
lCrt, b, err = cert.UnmarshalCertificateFromPEM(rb)
|
||||||
assert.Empty(t, b)
|
assert.Len(t, b, 0)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, lCrt.PublicKey(), inPub)
|
assert.Equal(t, lCrt.PublicKey(), inPub)
|
||||||
|
|
||||||
// test refuse to sign cert with duration beyond root
|
// test refuse to sign cert with duration beyond root
|
||||||
@@ -313,7 +312,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
os.Remove(keyF.Name())
|
os.Remove(keyF.Name())
|
||||||
os.Remove(crtF.Name())
|
os.Remove(crtF.Name())
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
require.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate")
|
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate")
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
@@ -321,14 +320,14 @@ func Test_signCert(t *testing.T) {
|
|||||||
os.Remove(keyF.Name())
|
os.Remove(keyF.Name())
|
||||||
os.Remove(crtF.Name())
|
os.Remove(crtF.Name())
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
require.NoError(t, signCert(args, ob, eb, nopw))
|
assert.Nil(t, signCert(args, ob, eb, nopw))
|
||||||
|
|
||||||
// test that we won't overwrite existing key file
|
// test that we won't overwrite existing key file
|
||||||
os.Remove(crtF.Name())
|
os.Remove(crtF.Name())
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
|
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
@@ -336,14 +335,14 @@ func Test_signCert(t *testing.T) {
|
|||||||
os.Remove(keyF.Name())
|
os.Remove(keyF.Name())
|
||||||
os.Remove(crtF.Name())
|
os.Remove(crtF.Name())
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
require.NoError(t, signCert(args, ob, eb, nopw))
|
assert.Nil(t, signCert(args, ob, eb, nopw))
|
||||||
|
|
||||||
// test that we won't overwrite existing certificate file
|
// test that we won't overwrite existing certificate file
|
||||||
os.Remove(keyF.Name())
|
os.Remove(keyF.Name())
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
|
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
@@ -356,11 +355,11 @@ func Test_signCert(t *testing.T) {
|
|||||||
eb.Reset()
|
eb.Reset()
|
||||||
|
|
||||||
caKeyF, err = os.CreateTemp("", "sign-cert.key")
|
caKeyF, err = os.CreateTemp("", "sign-cert.key")
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
defer os.Remove(caKeyF.Name())
|
defer os.Remove(caKeyF.Name())
|
||||||
|
|
||||||
caCrtF, err = os.CreateTemp("", "sign-cert.crt")
|
caCrtF, err = os.CreateTemp("", "sign-cert.crt")
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
defer os.Remove(caCrtF.Name())
|
defer os.Remove(caCrtF.Name())
|
||||||
|
|
||||||
// generate the encrypted key
|
// generate the encrypted key
|
||||||
@@ -375,7 +374,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
|
|
||||||
// test with the proper password
|
// test with the proper password
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
require.NoError(t, signCert(args, ob, eb, testpw))
|
assert.Nil(t, signCert(args, ob, eb, testpw))
|
||||||
assert.Equal(t, "Enter passphrase: ", ob.String())
|
assert.Equal(t, "Enter passphrase: ", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
@@ -385,7 +384,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
|
|
||||||
testpw.password = []byte("invalid password")
|
testpw.password = []byte("invalid password")
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
require.Error(t, signCert(args, ob, eb, testpw))
|
assert.Error(t, signCert(args, ob, eb, testpw))
|
||||||
assert.Equal(t, "Enter passphrase: ", ob.String())
|
assert.Equal(t, "Enter passphrase: ", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
@@ -394,7 +393,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
eb.Reset()
|
eb.Reset()
|
||||||
|
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
require.Error(t, signCert(args, ob, eb, nopw))
|
assert.Error(t, signCert(args, ob, eb, nopw))
|
||||||
// normally the user hitting enter on the prompt would add newlines between these
|
// normally the user hitting enter on the prompt would add newlines between these
|
||||||
assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
|
assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
@@ -404,7 +403,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
eb.Reset()
|
eb.Reset()
|
||||||
|
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
require.Error(t, signCert(args, ob, eb, errpw))
|
assert.Error(t, signCert(args, ob, eb, errpw))
|
||||||
assert.Equal(t, "Enter passphrase: ", ob.String())
|
assert.Equal(t, "Enter passphrase: ", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,13 +3,13 @@ package main
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -38,33 +38,33 @@ func Test_verify(t *testing.T) {
|
|||||||
|
|
||||||
// required args
|
// required args
|
||||||
assertHelpError(t, verify([]string{"-ca", "derp"}, ob, eb), "-crt is required")
|
assertHelpError(t, verify([]string{"-ca", "derp"}, ob, eb), "-crt is required")
|
||||||
assert.Empty(t, ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
|
|
||||||
assertHelpError(t, verify([]string{"-crt", "derp"}, ob, eb), "-ca is required")
|
assertHelpError(t, verify([]string{"-crt", "derp"}, ob, eb), "-ca is required")
|
||||||
assert.Empty(t, ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
|
|
||||||
// no ca at path
|
// no ca at path
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
err := verify([]string{"-ca", "does_not_exist", "-crt", "does_not_exist"}, ob, eb)
|
err := verify([]string{"-ca", "does_not_exist", "-crt", "does_not_exist"}, ob, eb)
|
||||||
assert.Empty(t, ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
require.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError)
|
assert.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError)
|
||||||
|
|
||||||
// invalid ca at path
|
// invalid ca at path
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
caFile, err := os.CreateTemp("", "verify-ca")
|
caFile, err := os.CreateTemp("", "verify-ca")
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
defer os.Remove(caFile.Name())
|
defer os.Remove(caFile.Name())
|
||||||
|
|
||||||
caFile.WriteString("-----BEGIN NOPE-----")
|
caFile.WriteString("-----BEGIN NOPE-----")
|
||||||
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
|
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
|
||||||
assert.Empty(t, ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block")
|
assert.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block")
|
||||||
|
|
||||||
// make a ca for later
|
// make a ca for later
|
||||||
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
|
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
|
||||||
@@ -76,22 +76,22 @@ func Test_verify(t *testing.T) {
|
|||||||
|
|
||||||
// no crt at path
|
// no crt at path
|
||||||
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
|
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
|
||||||
assert.Empty(t, ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
require.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError)
|
assert.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError)
|
||||||
|
|
||||||
// invalid crt at path
|
// invalid crt at path
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
certFile, err := os.CreateTemp("", "verify-cert")
|
certFile, err := os.CreateTemp("", "verify-cert")
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
defer os.Remove(certFile.Name())
|
defer os.Remove(certFile.Name())
|
||||||
|
|
||||||
certFile.WriteString("-----BEGIN NOPE-----")
|
certFile.WriteString("-----BEGIN NOPE-----")
|
||||||
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
|
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
|
||||||
assert.Empty(t, ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
require.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block")
|
assert.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block")
|
||||||
|
|
||||||
// unverifiable cert at path
|
// unverifiable cert at path
|
||||||
crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
|
crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
|
||||||
@@ -106,9 +106,9 @@ func Test_verify(t *testing.T) {
|
|||||||
certFile.Write(b)
|
certFile.Write(b)
|
||||||
|
|
||||||
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
|
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
|
||||||
assert.Empty(t, ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
require.ErrorIs(t, err, cert.ErrSignatureMismatch)
|
assert.True(t, errors.Is(err, cert.ErrSignatureMismatch))
|
||||||
|
|
||||||
// verified cert at path
|
// verified cert at path
|
||||||
crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
|
crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
|
||||||
@@ -118,7 +118,7 @@ func Test_verify(t *testing.T) {
|
|||||||
certFile.Write(b)
|
certFile.Write(b)
|
||||||
|
|
||||||
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
|
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
|
||||||
assert.Empty(t, ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,14 +17,14 @@ import (
|
|||||||
|
|
||||||
"dario.cat/mergo"
|
"dario.cat/mergo"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
type C struct {
|
type C struct {
|
||||||
path string
|
path string
|
||||||
files []string
|
files []string
|
||||||
Settings map[string]any
|
Settings map[interface{}]interface{}
|
||||||
oldSettings map[string]any
|
oldSettings map[interface{}]interface{}
|
||||||
callbacks []func(*C)
|
callbacks []func(*C)
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
reloadLock sync.Mutex
|
reloadLock sync.Mutex
|
||||||
@@ -32,7 +32,7 @@ type C struct {
|
|||||||
|
|
||||||
func NewC(l *logrus.Logger) *C {
|
func NewC(l *logrus.Logger) *C {
|
||||||
return &C{
|
return &C{
|
||||||
Settings: make(map[string]any),
|
Settings: make(map[interface{}]interface{}),
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -92,8 +92,8 @@ func (c *C) HasChanged(k string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
nv any
|
nv interface{}
|
||||||
ov any
|
ov interface{}
|
||||||
)
|
)
|
||||||
|
|
||||||
if k == "" {
|
if k == "" {
|
||||||
@@ -147,7 +147,7 @@ func (c *C) ReloadConfig() {
|
|||||||
c.reloadLock.Lock()
|
c.reloadLock.Lock()
|
||||||
defer c.reloadLock.Unlock()
|
defer c.reloadLock.Unlock()
|
||||||
|
|
||||||
c.oldSettings = make(map[string]any)
|
c.oldSettings = make(map[interface{}]interface{})
|
||||||
for k, v := range c.Settings {
|
for k, v := range c.Settings {
|
||||||
c.oldSettings[k] = v
|
c.oldSettings[k] = v
|
||||||
}
|
}
|
||||||
@@ -167,7 +167,7 @@ func (c *C) ReloadConfigString(raw string) error {
|
|||||||
c.reloadLock.Lock()
|
c.reloadLock.Lock()
|
||||||
defer c.reloadLock.Unlock()
|
defer c.reloadLock.Unlock()
|
||||||
|
|
||||||
c.oldSettings = make(map[string]any)
|
c.oldSettings = make(map[interface{}]interface{})
|
||||||
for k, v := range c.Settings {
|
for k, v := range c.Settings {
|
||||||
c.oldSettings[k] = v
|
c.oldSettings[k] = v
|
||||||
}
|
}
|
||||||
@@ -201,7 +201,7 @@ func (c *C) GetStringSlice(k string, d []string) []string {
|
|||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
rv, ok := r.([]any)
|
rv, ok := r.([]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
@@ -215,13 +215,13 @@ func (c *C) GetStringSlice(k string, d []string) []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetMap will get the map for k or return the default d if not found or invalid
|
// GetMap will get the map for k or return the default d if not found or invalid
|
||||||
func (c *C) GetMap(k string, d map[string]any) map[string]any {
|
func (c *C) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} {
|
||||||
r := c.Get(k)
|
r := c.Get(k)
|
||||||
if r == nil {
|
if r == nil {
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
v, ok := r.(map[string]any)
|
v, ok := r.(map[interface{}]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
@@ -243,7 +243,7 @@ func (c *C) GetInt(k string, d int) int {
|
|||||||
// GetUint32 will get the uint32 for k or return the default d if not found or invalid
|
// GetUint32 will get the uint32 for k or return the default d if not found or invalid
|
||||||
func (c *C) GetUint32(k string, d uint32) uint32 {
|
func (c *C) GetUint32(k string, d uint32) uint32 {
|
||||||
r := c.GetInt(k, int(d))
|
r := c.GetInt(k, int(d))
|
||||||
if r < 0 || uint64(r) > uint64(math.MaxUint32) {
|
if uint64(r) > uint64(math.MaxUint32) {
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
return uint32(r)
|
return uint32(r)
|
||||||
@@ -266,22 +266,6 @@ func (c *C) GetBool(k string, d bool) bool {
|
|||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
func AsBool(v any) (value bool, ok bool) {
|
|
||||||
switch x := v.(type) {
|
|
||||||
case bool:
|
|
||||||
return x, true
|
|
||||||
case string:
|
|
||||||
switch x {
|
|
||||||
case "y", "yes":
|
|
||||||
return true, true
|
|
||||||
case "n", "no":
|
|
||||||
return false, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDuration will get the duration for k or return the default d if not found or invalid
|
// GetDuration will get the duration for k or return the default d if not found or invalid
|
||||||
func (c *C) GetDuration(k string, d time.Duration) time.Duration {
|
func (c *C) GetDuration(k string, d time.Duration) time.Duration {
|
||||||
r := c.GetString(k, "")
|
r := c.GetString(k, "")
|
||||||
@@ -292,7 +276,7 @@ func (c *C) GetDuration(k string, d time.Duration) time.Duration {
|
|||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *C) Get(k string) any {
|
func (c *C) Get(k string) interface{} {
|
||||||
return c.get(k, c.Settings)
|
return c.get(k, c.Settings)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -300,10 +284,10 @@ func (c *C) IsSet(k string) bool {
|
|||||||
return c.get(k, c.Settings) != nil
|
return c.get(k, c.Settings) != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *C) get(k string, v any) any {
|
func (c *C) get(k string, v interface{}) interface{} {
|
||||||
parts := strings.Split(k, ".")
|
parts := strings.Split(k, ".")
|
||||||
for _, p := range parts {
|
for _, p := range parts {
|
||||||
m, ok := v.(map[string]any)
|
m, ok := v.(map[interface{}]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -362,7 +346,7 @@ func (c *C) addFile(path string, direct bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *C) parseRaw(b []byte) error {
|
func (c *C) parseRaw(b []byte) error {
|
||||||
var m map[string]any
|
var m map[interface{}]interface{}
|
||||||
|
|
||||||
err := yaml.Unmarshal(b, &m)
|
err := yaml.Unmarshal(b, &m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -374,7 +358,7 @@ func (c *C) parseRaw(b []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *C) parse() error {
|
func (c *C) parse() error {
|
||||||
var m map[string]any
|
var m map[interface{}]interface{}
|
||||||
|
|
||||||
for _, path := range c.files {
|
for _, path := range c.files {
|
||||||
b, err := os.ReadFile(path)
|
b, err := os.ReadFile(path)
|
||||||
@@ -382,7 +366,7 @@ func (c *C) parse() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var nm map[string]any
|
var nm map[interface{}]interface{}
|
||||||
err = yaml.Unmarshal(b, &nm)
|
err = yaml.Unmarshal(b, &nm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConfig_Load(t *testing.T) {
|
func TestConfig_Load(t *testing.T) {
|
||||||
@@ -19,20 +19,20 @@ func TestConfig_Load(t *testing.T) {
|
|||||||
// invalid yaml
|
// invalid yaml
|
||||||
c := NewC(l)
|
c := NewC(l)
|
||||||
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
|
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
|
||||||
require.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[string]interface {}")
|
assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
|
||||||
|
|
||||||
// simple multi config merge
|
// simple multi config merge
|
||||||
c = NewC(l)
|
c = NewC(l)
|
||||||
os.RemoveAll(dir)
|
os.RemoveAll(dir)
|
||||||
os.Mkdir(dir, 0755)
|
os.Mkdir(dir, 0755)
|
||||||
|
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
|
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
|
||||||
os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644)
|
os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644)
|
||||||
require.NoError(t, c.Load(dir))
|
assert.Nil(t, c.Load(dir))
|
||||||
expected := map[string]any{
|
expected := map[interface{}]interface{}{
|
||||||
"outer": map[string]any{
|
"outer": map[interface{}]interface{}{
|
||||||
"inner": "override",
|
"inner": "override",
|
||||||
},
|
},
|
||||||
"new": "hi",
|
"new": "hi",
|
||||||
@@ -44,12 +44,12 @@ func TestConfig_Get(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
// test simple type
|
// test simple type
|
||||||
c := NewC(l)
|
c := NewC(l)
|
||||||
c.Settings["firewall"] = map[string]any{"outbound": "hi"}
|
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
|
||||||
assert.Equal(t, "hi", c.Get("firewall.outbound"))
|
assert.Equal(t, "hi", c.Get("firewall.outbound"))
|
||||||
|
|
||||||
// test complex type
|
// test complex type
|
||||||
inner := []map[string]any{{"port": "1", "code": "2"}}
|
inner := []map[interface{}]interface{}{{"port": "1", "code": "2"}}
|
||||||
c.Settings["firewall"] = map[string]any{"outbound": inner}
|
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": inner}
|
||||||
assert.EqualValues(t, inner, c.Get("firewall.outbound"))
|
assert.EqualValues(t, inner, c.Get("firewall.outbound"))
|
||||||
|
|
||||||
// test missing
|
// test missing
|
||||||
@@ -59,7 +59,7 @@ func TestConfig_Get(t *testing.T) {
|
|||||||
func TestConfig_GetStringSlice(t *testing.T) {
|
func TestConfig_GetStringSlice(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
c := NewC(l)
|
c := NewC(l)
|
||||||
c.Settings["slice"] = []any{"one", "two"}
|
c.Settings["slice"] = []interface{}{"one", "two"}
|
||||||
assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
|
assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -67,28 +67,28 @@ func TestConfig_GetBool(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
c := NewC(l)
|
c := NewC(l)
|
||||||
c.Settings["bool"] = true
|
c.Settings["bool"] = true
|
||||||
assert.True(t, c.GetBool("bool", false))
|
assert.Equal(t, true, c.GetBool("bool", false))
|
||||||
|
|
||||||
c.Settings["bool"] = "true"
|
c.Settings["bool"] = "true"
|
||||||
assert.True(t, c.GetBool("bool", false))
|
assert.Equal(t, true, c.GetBool("bool", false))
|
||||||
|
|
||||||
c.Settings["bool"] = false
|
c.Settings["bool"] = false
|
||||||
assert.False(t, c.GetBool("bool", true))
|
assert.Equal(t, false, c.GetBool("bool", true))
|
||||||
|
|
||||||
c.Settings["bool"] = "false"
|
c.Settings["bool"] = "false"
|
||||||
assert.False(t, c.GetBool("bool", true))
|
assert.Equal(t, false, c.GetBool("bool", true))
|
||||||
|
|
||||||
c.Settings["bool"] = "Y"
|
c.Settings["bool"] = "Y"
|
||||||
assert.True(t, c.GetBool("bool", false))
|
assert.Equal(t, true, c.GetBool("bool", false))
|
||||||
|
|
||||||
c.Settings["bool"] = "yEs"
|
c.Settings["bool"] = "yEs"
|
||||||
assert.True(t, c.GetBool("bool", false))
|
assert.Equal(t, true, c.GetBool("bool", false))
|
||||||
|
|
||||||
c.Settings["bool"] = "N"
|
c.Settings["bool"] = "N"
|
||||||
assert.False(t, c.GetBool("bool", true))
|
assert.Equal(t, false, c.GetBool("bool", true))
|
||||||
|
|
||||||
c.Settings["bool"] = "nO"
|
c.Settings["bool"] = "nO"
|
||||||
assert.False(t, c.GetBool("bool", true))
|
assert.Equal(t, false, c.GetBool("bool", true))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_HasChanged(t *testing.T) {
|
func TestConfig_HasChanged(t *testing.T) {
|
||||||
@@ -101,14 +101,14 @@ func TestConfig_HasChanged(t *testing.T) {
|
|||||||
// Test key change
|
// Test key change
|
||||||
c = NewC(l)
|
c = NewC(l)
|
||||||
c.Settings["test"] = "hi"
|
c.Settings["test"] = "hi"
|
||||||
c.oldSettings = map[string]any{"test": "no"}
|
c.oldSettings = map[interface{}]interface{}{"test": "no"}
|
||||||
assert.True(t, c.HasChanged("test"))
|
assert.True(t, c.HasChanged("test"))
|
||||||
assert.True(t, c.HasChanged(""))
|
assert.True(t, c.HasChanged(""))
|
||||||
|
|
||||||
// No key change
|
// No key change
|
||||||
c = NewC(l)
|
c = NewC(l)
|
||||||
c.Settings["test"] = "hi"
|
c.Settings["test"] = "hi"
|
||||||
c.oldSettings = map[string]any{"test": "hi"}
|
c.oldSettings = map[interface{}]interface{}{"test": "hi"}
|
||||||
assert.False(t, c.HasChanged("test"))
|
assert.False(t, c.HasChanged("test"))
|
||||||
assert.False(t, c.HasChanged(""))
|
assert.False(t, c.HasChanged(""))
|
||||||
}
|
}
|
||||||
@@ -117,11 +117,11 @@ func TestConfig_ReloadConfig(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
done := make(chan bool, 1)
|
done := make(chan bool, 1)
|
||||||
dir, err := os.MkdirTemp("", "config-test")
|
dir, err := os.MkdirTemp("", "config-test")
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
|
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
|
||||||
|
|
||||||
c := NewC(l)
|
c := NewC(l)
|
||||||
require.NoError(t, c.Load(dir))
|
assert.Nil(t, c.Load(dir))
|
||||||
|
|
||||||
assert.False(t, c.HasChanged("outer.inner"))
|
assert.False(t, c.HasChanged("outer.inner"))
|
||||||
assert.False(t, c.HasChanged("outer"))
|
assert.False(t, c.HasChanged("outer"))
|
||||||
@@ -184,11 +184,11 @@ firewall:
|
|||||||
`),
|
`),
|
||||||
}
|
}
|
||||||
|
|
||||||
var m map[string]any
|
var m map[any]any
|
||||||
|
|
||||||
// merge the same way config.parse() merges
|
// merge the same way config.parse() merges
|
||||||
for _, b := range configs {
|
for _, b := range configs {
|
||||||
var nm map[string]any
|
var nm map[any]any
|
||||||
err := yaml.Unmarshal(b, &nm)
|
err := yaml.Unmarshal(b, &nm)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -205,15 +205,15 @@ firewall:
|
|||||||
t.Logf("Merged Config as YAML:\n%s", mYaml)
|
t.Logf("Merged Config as YAML:\n%s", mYaml)
|
||||||
|
|
||||||
// If a bug is present, some items might be replaced instead of merged like we expect
|
// If a bug is present, some items might be replaced instead of merged like we expect
|
||||||
expected := map[string]any{
|
expected := map[any]any{
|
||||||
"firewall": map[string]any{
|
"firewall": map[any]any{
|
||||||
"inbound": []any{
|
"inbound": []any{
|
||||||
map[string]any{"host": "any", "port": "any", "proto": "icmp"},
|
map[any]any{"host": "any", "port": "any", "proto": "icmp"},
|
||||||
map[string]any{"groups": []any{"server"}, "port": 443, "proto": "tcp"},
|
map[any]any{"groups": []any{"server"}, "port": 443, "proto": "tcp"},
|
||||||
map[string]any{"groups": []any{"webapp"}, "port": 443, "proto": "tcp"}},
|
map[any]any{"groups": []any{"webapp"}, "port": 443, "proto": "tcp"}},
|
||||||
"outbound": []any{
|
"outbound": []any{
|
||||||
map[string]any{"host": "any", "port": "any", "proto": "any"}}},
|
map[any]any{"host": "any", "port": "any", "proto": "any"}}},
|
||||||
"listen": map[string]any{
|
"listen": map[any]any{
|
||||||
"host": "0.0.0.0",
|
"host": "0.0.0.0",
|
||||||
"port": 4242,
|
"port": 4242,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -498,7 +498,7 @@ func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
|||||||
cs := n.intf.pki.getCertState()
|
cs := n.intf.pki.getCertState()
|
||||||
curCrt := hostinfo.ConnectionState.myCert
|
curCrt := hostinfo.ConnectionState.myCert
|
||||||
myCrt := cs.getCertificate(curCrt.Version())
|
myCrt := cs.getCertificate(curCrt.Version())
|
||||||
if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
|
if curCrt.Version() >= cs.defaultVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
|
||||||
// The current tunnel is using the latest certificate and version, no need to rehandshake.
|
// The current tunnel is using the latest certificate and version, no need to rehandshake.
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func newTestLighthouse() *LightHouse {
|
func newTestLighthouse() *LightHouse {
|
||||||
@@ -44,10 +43,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||||||
hostMap.preferredRanges.Store(&preferredRanges)
|
hostMap.preferredRanges.Store(&preferredRanges)
|
||||||
|
|
||||||
cs := &CertState{
|
cs := &CertState{
|
||||||
initiatingVersion: cert.Version1,
|
defaultVersion: cert.Version1,
|
||||||
privateKey: []byte{},
|
privateKey: []byte{},
|
||||||
v1Cert: &dummyCert{version: cert.Version1},
|
v1Cert: &dummyCert{version: cert.Version1},
|
||||||
v1HandshakeBytes: []byte{},
|
v1HandshakeBytes: []byte{},
|
||||||
}
|
}
|
||||||
|
|
||||||
lh := newTestLighthouse()
|
lh := newTestLighthouse()
|
||||||
@@ -126,10 +125,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||||||
hostMap.preferredRanges.Store(&preferredRanges)
|
hostMap.preferredRanges.Store(&preferredRanges)
|
||||||
|
|
||||||
cs := &CertState{
|
cs := &CertState{
|
||||||
initiatingVersion: cert.Version1,
|
defaultVersion: cert.Version1,
|
||||||
privateKey: []byte{},
|
privateKey: []byte{},
|
||||||
v1Cert: &dummyCert{version: cert.Version1},
|
v1Cert: &dummyCert{version: cert.Version1},
|
||||||
v1HandshakeBytes: []byte{},
|
v1HandshakeBytes: []byte{},
|
||||||
}
|
}
|
||||||
|
|
||||||
lh := newTestLighthouse()
|
lh := newTestLighthouse()
|
||||||
@@ -224,9 +223,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
caCert, err := tbs.Sign(nil, cert.Curve_CURVE25519, privCA)
|
caCert, err := tbs.Sign(nil, cert.Curve_CURVE25519, privCA)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
ncp := cert.NewCAPool()
|
ncp := cert.NewCAPool()
|
||||||
require.NoError(t, ncp.AddCA(caCert))
|
assert.NoError(t, ncp.AddCA(caCert))
|
||||||
|
|
||||||
pubCrt, _, _ := ed25519.GenerateKey(rand.Reader)
|
pubCrt, _, _ := ed25519.GenerateKey(rand.Reader)
|
||||||
tbs = &cert.TBSCertificate{
|
tbs = &cert.TBSCertificate{
|
||||||
@@ -238,7 +237,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||||||
PublicKey: pubCrt,
|
PublicKey: pubCrt,
|
||||||
}
|
}
|
||||||
peerCert, err := tbs.Sign(caCert, cert.Curve_CURVE25519, privCA)
|
peerCert, err := tbs.Sign(caCert, cert.Curve_CURVE25519, privCA)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
|
cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
|
||||||
|
|
||||||
|
|||||||
@@ -131,7 +131,8 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
|
|||||||
|
|
||||||
// GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
|
// GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
|
||||||
func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
|
func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
|
||||||
if c.f.myVpnAddrsTable.Contains(vpnIp) {
|
_, found := c.f.myVpnAddrsTable.Lookup(vpnIp)
|
||||||
|
if found {
|
||||||
// Only returning the default certificate since its impossible
|
// Only returning the default certificate since its impossible
|
||||||
// for any other host but ourselves to have more than 1
|
// for any other host but ourselves to have more than 1
|
||||||
return c.f.pki.getCertState().GetDefaultCertificate().Copy()
|
return c.f.pki.getCertState().GetDefaultCertificate().Copy()
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
|||||||
|
|
||||||
// Make sure we don't have any unexpected fields
|
// Make sure we don't have any unexpected fields
|
||||||
assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
|
assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
|
||||||
assert.Equal(t, &expectedInfo, thi)
|
assert.EqualValues(t, &expectedInfo, thi)
|
||||||
test.AssertDeepCopyEqual(t, &expectedInfo, thi)
|
test.AssertDeepCopyEqual(t, &expectedInfo, thi)
|
||||||
|
|
||||||
// Make sure we don't panic if the host info doesn't have a cert yet
|
// Make sure we don't panic if the host info doesn't have a cert yet
|
||||||
@@ -110,7 +110,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertFields(t *testing.T, expected []string, actualStruct any) {
|
func assertFields(t *testing.T, expected []string, actualStruct interface{}) {
|
||||||
val := reflect.ValueOf(actualStruct).Elem()
|
val := reflect.ValueOf(actualStruct).Elem()
|
||||||
fields := make([]string, val.NumField())
|
fields := make([]string, val.NumField())
|
||||||
for i := 0; i < val.NumField(); i++ {
|
for i := 0; i < val.NumField(); i++ {
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -27,7 +26,7 @@ type dnsRecords struct {
|
|||||||
dnsMap4 map[string]netip.Addr
|
dnsMap4 map[string]netip.Addr
|
||||||
dnsMap6 map[string]netip.Addr
|
dnsMap6 map[string]netip.Addr
|
||||||
hostMap *HostMap
|
hostMap *HostMap
|
||||||
myVpnAddrsTable *bart.Lite
|
myVpnAddrsTable *bart.Table[struct{}]
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
|
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
|
||||||
@@ -40,7 +39,7 @@ func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecord
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dnsRecords) query(q uint16, data string) netip.Addr {
|
func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
|
||||||
data = strings.ToLower(data)
|
data = strings.ToLower(data)
|
||||||
d.RLock()
|
d.RLock()
|
||||||
defer d.RUnlock()
|
defer d.RUnlock()
|
||||||
@@ -58,7 +57,7 @@ func (d *dnsRecords) query(q uint16, data string) netip.Addr {
|
|||||||
return netip.Addr{}
|
return netip.Addr{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dnsRecords) queryCert(data string) string {
|
func (d *dnsRecords) QueryCert(data string) string {
|
||||||
ip, err := netip.ParseAddr(data[:len(data)-1])
|
ip, err := netip.ParseAddr(data[:len(data)-1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ""
|
return ""
|
||||||
@@ -113,8 +112,8 @@ func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
//if we found it in this table, it's good
|
_, found := d.myVpnAddrsTable.Lookup(b)
|
||||||
return d.myVpnAddrsTable.Contains(b)
|
return found //if we found it in this table, it's good
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
||||||
@@ -123,7 +122,7 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
|||||||
case dns.TypeA, dns.TypeAAAA:
|
case dns.TypeA, dns.TypeAAAA:
|
||||||
qType := dns.TypeToString[q.Qtype]
|
qType := dns.TypeToString[q.Qtype]
|
||||||
d.l.Debugf("Query for %s %s", qType, q.Name)
|
d.l.Debugf("Query for %s %s", qType, q.Name)
|
||||||
ip := d.query(q.Qtype, q.Name)
|
ip := d.Query(q.Qtype, q.Name)
|
||||||
if ip.IsValid() {
|
if ip.IsValid() {
|
||||||
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
|
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -136,7 +135,7 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
d.l.Debugf("Query for TXT %s", q.Name)
|
d.l.Debugf("Query for TXT %s", q.Name)
|
||||||
ip := d.queryCert(q.Name)
|
ip := d.QueryCert(q.Name)
|
||||||
if ip != "" {
|
if ip != "" {
|
||||||
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
|
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -164,18 +163,18 @@ func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
w.WriteMsg(m)
|
w.WriteMsg(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
func dnsMain(ctx context.Context, l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
|
func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
|
||||||
dnsR = newDnsRecords(l, cs, hostMap)
|
dnsR = newDnsRecords(l, cs, hostMap)
|
||||||
|
|
||||||
// attach request handler func
|
// attach request handler func
|
||||||
dns.HandleFunc(".", dnsR.handleDnsRequest)
|
dns.HandleFunc(".", dnsR.handleDnsRequest)
|
||||||
|
|
||||||
c.RegisterReloadCallback(func(c *config.C) {
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
reloadDns(ctx, l, c)
|
reloadDns(l, c)
|
||||||
})
|
})
|
||||||
|
|
||||||
return func() {
|
return func() {
|
||||||
startDns(ctx, l, c)
|
startDns(l, c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -188,24 +187,24 @@ func getDnsServerAddr(c *config.C) string {
|
|||||||
return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
|
return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func startDns(ctx context.Context, l *logrus.Logger, c *config.C) {
|
func startDns(l *logrus.Logger, c *config.C) {
|
||||||
dnsAddr = getDnsServerAddr(c)
|
dnsAddr = getDnsServerAddr(c)
|
||||||
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
|
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
|
||||||
l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder")
|
l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder")
|
||||||
err := dnsServer.ListenAndServe()
|
err := dnsServer.ListenAndServe()
|
||||||
defer dnsServer.ShutdownContext(ctx)
|
defer dnsServer.Shutdown()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Errorf("Failed to start server: %s\n ", err.Error())
|
l.Errorf("Failed to start server: %s\n ", err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func reloadDns(ctx context.Context, l *logrus.Logger, c *config.C) {
|
func reloadDns(l *logrus.Logger, c *config.C) {
|
||||||
if dnsAddr == getDnsServerAddr(c) {
|
if dnsAddr == getDnsServerAddr(c) {
|
||||||
l.Debug("No DNS server config change detected")
|
l.Debug("No DNS server config change detected")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
l.Debug("Restarting DNS server")
|
l.Debug("Restarting DNS server")
|
||||||
dnsServer.ShutdownContext(ctx)
|
dnsServer.Shutdown()
|
||||||
go startDns(ctx, l, c)
|
go startDns(l, c)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,24 +38,24 @@ func TestParsequery(t *testing.T) {
|
|||||||
func Test_getDnsServerAddr(t *testing.T) {
|
func Test_getDnsServerAddr(t *testing.T) {
|
||||||
c := config.NewC(nil)
|
c := config.NewC(nil)
|
||||||
|
|
||||||
c.Settings["lighthouse"] = map[string]any{
|
c.Settings["lighthouse"] = map[interface{}]interface{}{
|
||||||
"dns": map[string]any{
|
"dns": map[interface{}]interface{}{
|
||||||
"host": "0.0.0.0",
|
"host": "0.0.0.0",
|
||||||
"port": "1",
|
"port": "1",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
assert.Equal(t, "0.0.0.0:1", getDnsServerAddr(c))
|
assert.Equal(t, "0.0.0.0:1", getDnsServerAddr(c))
|
||||||
|
|
||||||
c.Settings["lighthouse"] = map[string]any{
|
c.Settings["lighthouse"] = map[interface{}]interface{}{
|
||||||
"dns": map[string]any{
|
"dns": map[interface{}]interface{}{
|
||||||
"host": "::",
|
"host": "::",
|
||||||
"port": "1",
|
"port": "1",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
|
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
|
||||||
|
|
||||||
c.Settings["lighthouse"] = map[string]any{
|
c.Settings["lighthouse"] = map[interface{}]interface{}{
|
||||||
"dns": map[string]any{
|
"dns": map[interface{}]interface{}{
|
||||||
"host": "[::]",
|
"host": "[::]",
|
||||||
"port": "1",
|
"port": "1",
|
||||||
},
|
},
|
||||||
@@ -63,8 +63,8 @@ func Test_getDnsServerAddr(t *testing.T) {
|
|||||||
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
|
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
|
||||||
|
|
||||||
// Make sure whitespace doesn't mess us up
|
// Make sure whitespace doesn't mess us up
|
||||||
c.Settings["lighthouse"] = map[string]any{
|
c.Settings["lighthouse"] = map[interface{}]interface{}{
|
||||||
"dns": map[string]any{
|
"dns": map[interface{}]interface{}{
|
||||||
"host": "[::] ",
|
"host": "[::] ",
|
||||||
"port": "1",
|
"port": "1",
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -19,8 +19,7 @@ import (
|
|||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"gopkg.in/yaml.v2"
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func BenchmarkHotPath(b *testing.B) {
|
func BenchmarkHotPath(b *testing.B) {
|
||||||
@@ -772,7 +771,7 @@ func TestRehandshakingRelays(t *testing.T) {
|
|||||||
"key": string(myNextPrivKey),
|
"key": string(myNextPrivKey),
|
||||||
}
|
}
|
||||||
rc, err := yaml.Marshal(relayConfig.Settings)
|
rc, err := yaml.Marshal(relayConfig.Settings)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
relayConfig.ReloadConfigString(string(rc))
|
relayConfig.ReloadConfigString(string(rc))
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -876,7 +875,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
|
|||||||
"key": string(myNextPrivKey),
|
"key": string(myNextPrivKey),
|
||||||
}
|
}
|
||||||
rc, err := yaml.Marshal(relayConfig.Settings)
|
rc, err := yaml.Marshal(relayConfig.Settings)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
relayConfig.ReloadConfigString(string(rc))
|
relayConfig.ReloadConfigString(string(rc))
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -971,7 +970,7 @@ func TestRehandshaking(t *testing.T) {
|
|||||||
"key": string(myNextPrivKey),
|
"key": string(myNextPrivKey),
|
||||||
}
|
}
|
||||||
rc, err := yaml.Marshal(myConfig.Settings)
|
rc, err := yaml.Marshal(myConfig.Settings)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
myConfig.ReloadConfigString(string(rc))
|
myConfig.ReloadConfigString(string(rc))
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -988,17 +987,17 @@ func TestRehandshaking(t *testing.T) {
|
|||||||
r.Log("Got the new cert")
|
r.Log("Got the new cert")
|
||||||
// Flip their firewall to only allowing the new group to catch the tunnels reverting incorrectly
|
// Flip their firewall to only allowing the new group to catch the tunnels reverting incorrectly
|
||||||
rc, err = yaml.Marshal(theirConfig.Settings)
|
rc, err = yaml.Marshal(theirConfig.Settings)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
var theirNewConfig m
|
var theirNewConfig m
|
||||||
require.NoError(t, yaml.Unmarshal(rc, &theirNewConfig))
|
assert.NoError(t, yaml.Unmarshal(rc, &theirNewConfig))
|
||||||
theirFirewall := theirNewConfig["firewall"].(map[string]any)
|
theirFirewall := theirNewConfig["firewall"].(map[interface{}]interface{})
|
||||||
theirFirewall["inbound"] = []m{{
|
theirFirewall["inbound"] = []m{{
|
||||||
"proto": "any",
|
"proto": "any",
|
||||||
"port": "any",
|
"port": "any",
|
||||||
"group": "new group",
|
"group": "new group",
|
||||||
}}
|
}}
|
||||||
rc, err = yaml.Marshal(theirNewConfig)
|
rc, err = yaml.Marshal(theirNewConfig)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
theirConfig.ReloadConfigString(string(rc))
|
theirConfig.ReloadConfigString(string(rc))
|
||||||
|
|
||||||
r.Log("Spin until there is only 1 tunnel")
|
r.Log("Spin until there is only 1 tunnel")
|
||||||
@@ -1068,7 +1067,7 @@ func TestRehandshakingLoser(t *testing.T) {
|
|||||||
"key": string(theirNextPrivKey),
|
"key": string(theirNextPrivKey),
|
||||||
}
|
}
|
||||||
rc, err := yaml.Marshal(theirConfig.Settings)
|
rc, err := yaml.Marshal(theirConfig.Settings)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
theirConfig.ReloadConfigString(string(rc))
|
theirConfig.ReloadConfigString(string(rc))
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -1084,17 +1083,17 @@ func TestRehandshakingLoser(t *testing.T) {
|
|||||||
|
|
||||||
// Flip my firewall to only allowing the new group to catch the tunnels reverting incorrectly
|
// Flip my firewall to only allowing the new group to catch the tunnels reverting incorrectly
|
||||||
rc, err = yaml.Marshal(myConfig.Settings)
|
rc, err = yaml.Marshal(myConfig.Settings)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
var myNewConfig m
|
var myNewConfig m
|
||||||
require.NoError(t, yaml.Unmarshal(rc, &myNewConfig))
|
assert.NoError(t, yaml.Unmarshal(rc, &myNewConfig))
|
||||||
theirFirewall := myNewConfig["firewall"].(map[string]any)
|
theirFirewall := myNewConfig["firewall"].(map[interface{}]interface{})
|
||||||
theirFirewall["inbound"] = []m{{
|
theirFirewall["inbound"] = []m{{
|
||||||
"proto": "any",
|
"proto": "any",
|
||||||
"port": "any",
|
"port": "any",
|
||||||
"group": "their new group",
|
"group": "their new group",
|
||||||
}}
|
}}
|
||||||
rc, err = yaml.Marshal(myNewConfig)
|
rc, err = yaml.Marshal(myNewConfig)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
myConfig.ReloadConfigString(string(rc))
|
myConfig.ReloadConfigString(string(rc))
|
||||||
|
|
||||||
r.Log("Spin until there is only 1 tunnel")
|
r.Log("Spin until there is only 1 tunnel")
|
||||||
|
|||||||
@@ -22,10 +22,10 @@ import (
|
|||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/e2e/router"
|
"github.com/slackhq/nebula/e2e/router"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
type m = map[string]any
|
type m map[string]interface{}
|
||||||
|
|
||||||
// newSimpleServer creates a nebula instance with many assumptions
|
// newSimpleServer creates a nebula instance with many assumptions
|
||||||
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||||
|
|||||||
@@ -13,11 +13,11 @@ pki:
|
|||||||
# disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
|
# disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
|
||||||
#disconnect_invalid: true
|
#disconnect_invalid: true
|
||||||
|
|
||||||
# initiating_version controls which certificate version is used when initiating handshakes.
|
# default_version controls which certificate version is used in handshakes.
|
||||||
# This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`.
|
# This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`.
|
||||||
# Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`.
|
# Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`.
|
||||||
# After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed.
|
# After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed.
|
||||||
# initiating_version: 1
|
# default_version: 1
|
||||||
|
|
||||||
# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
|
# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
|
||||||
# A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
|
# A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
|
||||||
@@ -126,8 +126,8 @@ lighthouse:
|
|||||||
# Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined,
|
# Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined,
|
||||||
# however using port 0 will dynamically assign a port and is recommended for roaming nodes.
|
# however using port 0 will dynamically assign a port and is recommended for roaming nodes.
|
||||||
listen:
|
listen:
|
||||||
# To listen on only ipv4, use "0.0.0.0"
|
# To listen on both any ipv4 and ipv6 use "::"
|
||||||
host: "::"
|
host: 0.0.0.0
|
||||||
port: 4242
|
port: 4242
|
||||||
# Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg)
|
# Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg)
|
||||||
# default is 64, does not support reload
|
# default is 64, does not support reload
|
||||||
@@ -144,11 +144,6 @@ listen:
|
|||||||
# valid values: always, never, private
|
# valid values: always, never, private
|
||||||
# This setting is reloadable.
|
# This setting is reloadable.
|
||||||
#send_recv_error: always
|
#send_recv_error: always
|
||||||
# The so_sock option is a Linux-specific feature that allows all outgoing Nebula packets to be tagged with a specific identifier.
|
|
||||||
# This tagging enables IP rule-based filtering. For example, it supports 0.0.0.0/0 unsafe_routes,
|
|
||||||
# allowing for more precise routing decisions based on the packet tags. Default is 0 meaning no mark is set.
|
|
||||||
# This setting is reloadable.
|
|
||||||
#so_mark: 0
|
|
||||||
|
|
||||||
# Routines is the number of thread pairs to run that consume from the tun and UDP queues.
|
# Routines is the number of thread pairs to run that consume from the tun and UDP queues.
|
||||||
# Currently, this defaults to 1 which means we have 1 tun queue reader and 1
|
# Currently, this defaults to 1 which means we have 1 tun queue reader and 1
|
||||||
@@ -239,28 +234,7 @@ tun:
|
|||||||
|
|
||||||
# Unsafe routes allows you to route traffic over nebula to non-nebula nodes
|
# Unsafe routes allows you to route traffic over nebula to non-nebula nodes
|
||||||
# Unsafe routes should be avoided unless you have hosts/services that cannot run nebula
|
# Unsafe routes should be avoided unless you have hosts/services that cannot run nebula
|
||||||
# Supports weighted ECMP if you define a list of gateways, this can be used for load balancing or redundancy to hosts outside of nebula
|
# NOTE: The nebula certificate of the "via" node *MUST* have the "route" defined as a subnet in its certificate
|
||||||
# NOTES:
|
|
||||||
# * You will only see a single gateway in the routing table if you are not on linux
|
|
||||||
# * If a gateway is not reachable through the overlay another gateway will be selected to send the traffic through, ignoring weights
|
|
||||||
#
|
|
||||||
# unsafe_routes:
|
|
||||||
# # Multiple gateways without defining a weight defaults to a weight of 1, this will balance traffic equally between the three gateways
|
|
||||||
# - route: 192.168.87.0/24
|
|
||||||
# via:
|
|
||||||
# - gateway: 10.0.0.1
|
|
||||||
# - gateway: 10.0.0.2
|
|
||||||
# - gateway: 10.0.0.3
|
|
||||||
# # Multiple gateways with a weight, this will balance traffic accordingly
|
|
||||||
# - route: 192.168.87.0/24
|
|
||||||
# via:
|
|
||||||
# - gateway: 10.0.0.1
|
|
||||||
# weight: 10
|
|
||||||
# - gateway: 10.0.0.2
|
|
||||||
# weight: 5
|
|
||||||
#
|
|
||||||
# NOTE: The nebula certificate of the "via" node(s) *MUST* have the "route" defined as a subnet in its certificate
|
|
||||||
# `via`: single node or list of gateways to use for this route
|
|
||||||
# `mtu`: will default to tun mtu if this option is not specified
|
# `mtu`: will default to tun mtu if this option is not specified
|
||||||
# `metric`: will default to 0 if this option is not specified
|
# `metric`: will default to 0 if this option is not specified
|
||||||
# `install`: will default to true, controls whether this route is installed in the systems routing table.
|
# `install`: will default to true, controls whether this route is installed in the systems routing table.
|
||||||
@@ -346,11 +320,11 @@ firewall:
|
|||||||
outbound_action: drop
|
outbound_action: drop
|
||||||
inbound_action: drop
|
inbound_action: drop
|
||||||
|
|
||||||
# THIS FLAG IS DEPRECATED AND WILL BE REMOVED IN A FUTURE RELEASE. (Defaults to false.)
|
# Controls the default value for local_cidr. Default is true, will be deprecated after v1.9 and defaulted to false.
|
||||||
# This setting only affects nebula hosts exposing unsafe_routes. When set to false, each inbound rule must contain a
|
# This setting only affects nebula hosts with subnets encoded in their certificate. A nebula host acting as an
|
||||||
# `local_cidr` if the intention is to allow traffic to flow to an unsafe route. When set to true, every firewall rule
|
# unsafe router with `default_local_cidr_any: true` will expose their unsafe routes to every inbound rule regardless
|
||||||
# will apply to all configured unsafe_routes regardless of the actual destination of the packet, unless `local_cidr`
|
# of the actual destination for the packet. Setting this to false requires each inbound rule to contain a `local_cidr`
|
||||||
# is explicitly defined. This is usually not the desired behavior and should be avoided!
|
# if the intention is to allow traffic to flow to an unsafe route.
|
||||||
#default_local_cidr_any: false
|
#default_local_cidr_any: false
|
||||||
|
|
||||||
conntrack:
|
conntrack:
|
||||||
@@ -368,9 +342,11 @@ firewall:
|
|||||||
# group: `any` or a literal group name, ie `default-group`
|
# group: `any` or a literal group name, ie `default-group`
|
||||||
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
|
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
|
||||||
# cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6.
|
# cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6.
|
||||||
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This can be used to filter destinations when using unsafe_routes.
|
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This could be used to filter destinations when using unsafe_routes.
|
||||||
# By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true.
|
# If no unsafe networks are present in the certificate(s) or `default_local_cidr_any` is true then the default is any ipv4 or ipv6 network.
|
||||||
# If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended us case.
|
# Otherwise the default is any vpn network assigned to via the certificate.
|
||||||
|
# `default_local_cidr_any` defaults to false and is deprecated, it will be removed in a future release.
|
||||||
|
# If there are unsafe routes present its best to set `local_cidr` to whatever best fits the situation.
|
||||||
# ca_name: An issuing CA name
|
# ca_name: An issuing CA name
|
||||||
# ca_sha: An issuing CA shasum
|
# ca_sha: An issuing CA shasum
|
||||||
|
|
||||||
|
|||||||
50
firewall.go
50
firewall.go
@@ -53,7 +53,7 @@ type Firewall struct {
|
|||||||
|
|
||||||
// routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
|
// routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
|
||||||
// The vpn addresses are a full bit match while the unsafe networks only match the prefix
|
// The vpn addresses are a full bit match while the unsafe networks only match the prefix
|
||||||
routableNetworks *bart.Lite
|
routableNetworks *bart.Table[struct{}]
|
||||||
|
|
||||||
// assignedNetworks is a list of vpn networks assigned to us in the certificate.
|
// assignedNetworks is a list of vpn networks assigned to us in the certificate.
|
||||||
assignedNetworks []netip.Prefix
|
assignedNetworks []netip.Prefix
|
||||||
@@ -125,7 +125,7 @@ type firewallPort map[int32]*FirewallCA
|
|||||||
|
|
||||||
type firewallLocalCIDR struct {
|
type firewallLocalCIDR struct {
|
||||||
Any bool
|
Any bool
|
||||||
LocalCIDR *bart.Lite
|
LocalCIDR *bart.Table[struct{}]
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
||||||
@@ -148,17 +148,17 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
|||||||
tmax = defaultTimeout
|
tmax = defaultTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
routableNetworks := new(bart.Lite)
|
routableNetworks := new(bart.Table[struct{}])
|
||||||
var assignedNetworks []netip.Prefix
|
var assignedNetworks []netip.Prefix
|
||||||
for _, network := range c.Networks() {
|
for _, network := range c.Networks() {
|
||||||
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
||||||
routableNetworks.Insert(nprefix)
|
routableNetworks.Insert(nprefix, struct{}{})
|
||||||
assignedNetworks = append(assignedNetworks, network)
|
assignedNetworks = append(assignedNetworks, network)
|
||||||
}
|
}
|
||||||
|
|
||||||
hasUnsafeNetworks := false
|
hasUnsafeNetworks := false
|
||||||
for _, n := range c.UnsafeNetworks() {
|
for _, n := range c.UnsafeNetworks() {
|
||||||
routableNetworks.Insert(n)
|
routableNetworks.Insert(n, struct{}{})
|
||||||
hasUnsafeNetworks = true
|
hasUnsafeNetworks = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -331,7 +331,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rs, ok := r.([]any)
|
rs, ok := r.([]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("%s failed to parse, should be an array of rules", table)
|
return fmt.Errorf("%s failed to parse, should be an array of rules", table)
|
||||||
}
|
}
|
||||||
@@ -431,7 +431,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
|||||||
|
|
||||||
// Make sure remote address matches nebula certificate
|
// Make sure remote address matches nebula certificate
|
||||||
if h.networks != nil {
|
if h.networks != nil {
|
||||||
if !h.networks.Contains(fp.RemoteAddr) {
|
_, ok := h.networks.Lookup(fp.RemoteAddr)
|
||||||
|
if !ok {
|
||||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||||
return ErrInvalidRemoteIP
|
return ErrInvalidRemoteIP
|
||||||
}
|
}
|
||||||
@@ -444,7 +445,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Make sure we are supposed to be handling this local ip address
|
// Make sure we are supposed to be handling this local ip address
|
||||||
if !f.routableNetworks.Contains(fp.LocalAddr) {
|
_, ok := f.routableNetworks.Lookup(fp.LocalAddr)
|
||||||
|
if !ok {
|
||||||
f.metrics(incoming).droppedLocalAddr.Inc(1)
|
f.metrics(incoming).droppedLocalAddr.Inc(1)
|
||||||
return ErrInvalidLocalIP
|
return ErrInvalidLocalIP
|
||||||
}
|
}
|
||||||
@@ -750,7 +752,7 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool
|
|||||||
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
|
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
|
||||||
flc := func() *firewallLocalCIDR {
|
flc := func() *firewallLocalCIDR {
|
||||||
return &firewallLocalCIDR{
|
return &firewallLocalCIDR{
|
||||||
LocalCIDR: new(bart.Lite),
|
LocalCIDR: new(bart.Table[struct{}]),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -860,13 +862,16 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, v := range fr.CIDR.Supernets(netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())) {
|
matched := false
|
||||||
if v.match(p, c) {
|
prefix := netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())
|
||||||
return true
|
fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool {
|
||||||
|
if prefix.Contains(p.RemoteAddr) && val.match(p, c) {
|
||||||
|
matched = true
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
}
|
return true
|
||||||
|
})
|
||||||
return false
|
return matched
|
||||||
}
|
}
|
||||||
|
|
||||||
func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
|
func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
|
||||||
@@ -877,7 +882,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, network := range f.assignedNetworks {
|
for _, network := range f.assignedNetworks {
|
||||||
flc.LocalCIDR.Insert(network)
|
flc.LocalCIDR.Insert(network, struct{}{})
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
@@ -886,7 +891,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
flc.LocalCIDR.Insert(localIp)
|
flc.LocalCIDR.Insert(localIp, struct{}{})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -899,7 +904,8 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
return flc.LocalCIDR.Contains(p.LocalAddr)
|
_, ok := flc.LocalCIDR.Lookup(p.LocalAddr)
|
||||||
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
type rule struct {
|
type rule struct {
|
||||||
@@ -915,15 +921,15 @@ type rule struct {
|
|||||||
CASha string
|
CASha string
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
|
func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) {
|
||||||
r := rule{}
|
r := rule{}
|
||||||
|
|
||||||
m, ok := p.(map[string]any)
|
m, ok := p.(map[interface{}]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
return r, errors.New("could not parse rule")
|
return r, errors.New("could not parse rule")
|
||||||
}
|
}
|
||||||
|
|
||||||
toString := func(k string, m map[string]any) string {
|
toString := func(k string, m map[interface{}]interface{}) string {
|
||||||
v, ok := m[k]
|
v, ok := m[k]
|
||||||
if !ok {
|
if !ok {
|
||||||
return ""
|
return ""
|
||||||
@@ -941,7 +947,7 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
|
|||||||
r.CASha = toString("ca_sha", m)
|
r.CASha = toString("ca_sha", m)
|
||||||
|
|
||||||
// Make sure group isn't an array
|
// Make sure group isn't an array
|
||||||
if v, ok := m["group"].([]any); ok {
|
if v, ok := m["group"].([]interface{}); ok {
|
||||||
if len(v) > 1 {
|
if len(v) > 1 {
|
||||||
return r, errors.New("group should contain a single value, an array with more than one entry was provided")
|
return r, errors.New("group should contain a single value, an array with more than one entry was provided")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
)
|
)
|
||||||
|
|
||||||
type m = map[string]any
|
type m map[string]interface{}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever
|
ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever
|
||||||
|
|||||||
205
firewall_test.go
205
firewall_test.go
@@ -66,61 +66,61 @@ func TestFirewall_AddRule(t *testing.T) {
|
|||||||
assert.NotNil(t, fw.OutRules)
|
assert.NotNil(t, fw.OutRules)
|
||||||
|
|
||||||
ti, err := netip.ParsePrefix("1.2.3.4/32")
|
ti, err := netip.ParsePrefix("1.2.3.4/32")
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
// An empty rule is any
|
// An empty rule is any
|
||||||
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
|
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
|
||||||
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
|
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
|
||||||
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
|
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
assert.Nil(t, fw.InRules.UDP[1].Any.Any)
|
assert.Nil(t, fw.InRules.UDP[1].Any.Any)
|
||||||
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
|
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
|
||||||
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
|
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
|
assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
|
||||||
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
|
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
|
||||||
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
|
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
|
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
|
||||||
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
|
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||||
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
|
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
|
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
|
||||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||||
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
|
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
|
||||||
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
|
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
|
||||||
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
|
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
|
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
anyIp, err := netip.ParsePrefix("0.0.0.0/0")
|
anyIp, err := netip.ParsePrefix("0.0.0.0/0")
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
|
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
|
||||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||||
|
|
||||||
// Test error conditions
|
// Test error conditions
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop(t *testing.T) {
|
func TestFirewall_Drop(t *testing.T) {
|
||||||
@@ -155,16 +155,16 @@ func TestFirewall_Drop(t *testing.T) {
|
|||||||
h.buildNetworks(c.networks, c.unsafeNetworks)
|
h.buildNetworks(c.networks, c.unsafeNetworks)
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// Drop outbound
|
// Drop outbound
|
||||||
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
|
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
|
||||||
// Allow inbound
|
// Allow inbound
|
||||||
resetConntrack(fw)
|
resetConntrack(fw)
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||||
// Allow outbound because conntrack
|
// Allow outbound because conntrack
|
||||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
||||||
|
|
||||||
// test remote mismatch
|
// test remote mismatch
|
||||||
oldRemote := p.RemoteAddr
|
oldRemote := p.RemoteAddr
|
||||||
@@ -174,29 +174,29 @@ func TestFirewall_Drop(t *testing.T) {
|
|||||||
|
|
||||||
// ensure signer doesn't get in the way of group checks
|
// ensure signer doesn't get in the way of group checks
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
||||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
|
||||||
// test caSha doesn't drop on match
|
// test caSha doesn't drop on match
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||||
|
|
||||||
// ensure ca name doesn't get in the way of group checks
|
// ensure ca name doesn't get in the way of group checks
|
||||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
||||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
|
||||||
// test caName doesn't drop on match
|
// test caName doesn't drop on match
|
||||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkFirewallTable_match(b *testing.B) {
|
func BenchmarkFirewallTable_match(b *testing.B) {
|
||||||
@@ -350,14 +350,14 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||||||
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// h1/c1 lacks the proper groups
|
// h1/c1 lacks the proper groups
|
||||||
require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
|
assert.Error(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
|
||||||
// c has the proper groups
|
// c has the proper groups
|
||||||
resetConntrack(fw)
|
resetConntrack(fw)
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop3(t *testing.T) {
|
func TestFirewall_Drop3(t *testing.T) {
|
||||||
@@ -428,23 +428,18 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
|
h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// c1 should pass because host match
|
// c1 should pass because host match
|
||||||
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
|
assert.NoError(t, fw.Drop(p, true, &h1, cp, nil))
|
||||||
// c2 should pass because ca sha match
|
// c2 should pass because ca sha match
|
||||||
resetConntrack(fw)
|
resetConntrack(fw)
|
||||||
require.NoError(t, fw.Drop(p, true, &h2, cp, nil))
|
assert.NoError(t, fw.Drop(p, true, &h2, cp, nil))
|
||||||
// c3 should fail because no match
|
// c3 should fail because no match
|
||||||
resetConntrack(fw)
|
resetConntrack(fw)
|
||||||
assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
|
||||||
|
|
||||||
// Test a remote address match
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", ""))
|
|
||||||
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_DropConntrackReload(t *testing.T) {
|
func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||||
@@ -480,29 +475,29 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||||||
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// Drop outbound
|
// Drop outbound
|
||||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
// Allow inbound
|
// Allow inbound
|
||||||
resetConntrack(fw)
|
resetConntrack(fw)
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||||
// Allow outbound because conntrack
|
// Allow outbound because conntrack
|
||||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
||||||
|
|
||||||
oldFw := fw
|
oldFw := fw
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
fw.Conntrack = oldFw.Conntrack
|
fw.Conntrack = oldFw.Conntrack
|
||||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||||
|
|
||||||
// Allow outbound because conntrack and new rules allow port 10
|
// Allow outbound because conntrack and new rules allow port 10
|
||||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
||||||
|
|
||||||
oldFw = fw
|
oldFw = fw
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
fw.Conntrack = oldFw.Conntrack
|
fw.Conntrack = oldFw.Conntrack
|
||||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||||
|
|
||||||
@@ -585,42 +580,42 @@ func BenchmarkLookup(b *testing.B) {
|
|||||||
|
|
||||||
func Test_parsePort(t *testing.T) {
|
func Test_parsePort(t *testing.T) {
|
||||||
_, _, err := parsePort("")
|
_, _, err := parsePort("")
|
||||||
require.EqualError(t, err, "was not a number; ``")
|
assert.EqualError(t, err, "was not a number; ``")
|
||||||
|
|
||||||
_, _, err = parsePort(" ")
|
_, _, err = parsePort(" ")
|
||||||
require.EqualError(t, err, "was not a number; ` `")
|
assert.EqualError(t, err, "was not a number; ` `")
|
||||||
|
|
||||||
_, _, err = parsePort("-")
|
_, _, err = parsePort("-")
|
||||||
require.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
|
assert.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
|
||||||
|
|
||||||
_, _, err = parsePort(" - ")
|
_, _, err = parsePort(" - ")
|
||||||
require.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
|
assert.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
|
||||||
|
|
||||||
_, _, err = parsePort("a-b")
|
_, _, err = parsePort("a-b")
|
||||||
require.EqualError(t, err, "beginning range was not a number; `a`")
|
assert.EqualError(t, err, "beginning range was not a number; `a`")
|
||||||
|
|
||||||
_, _, err = parsePort("1-b")
|
_, _, err = parsePort("1-b")
|
||||||
require.EqualError(t, err, "ending range was not a number; `b`")
|
assert.EqualError(t, err, "ending range was not a number; `b`")
|
||||||
|
|
||||||
s, e, err := parsePort(" 1 - 2 ")
|
s, e, err := parsePort(" 1 - 2 ")
|
||||||
assert.Equal(t, int32(1), s)
|
assert.Equal(t, int32(1), s)
|
||||||
assert.Equal(t, int32(2), e)
|
assert.Equal(t, int32(2), e)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
s, e, err = parsePort("0-1")
|
s, e, err = parsePort("0-1")
|
||||||
assert.Equal(t, int32(0), s)
|
assert.Equal(t, int32(0), s)
|
||||||
assert.Equal(t, int32(0), e)
|
assert.Equal(t, int32(0), e)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
s, e, err = parsePort("9919")
|
s, e, err = parsePort("9919")
|
||||||
assert.Equal(t, int32(9919), s)
|
assert.Equal(t, int32(9919), s)
|
||||||
assert.Equal(t, int32(9919), e)
|
assert.Equal(t, int32(9919), e)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
s, e, err = parsePort("any")
|
s, e, err = parsePort("any")
|
||||||
assert.Equal(t, int32(0), s)
|
assert.Equal(t, int32(0), s)
|
||||||
assert.Equal(t, int32(0), e)
|
assert.Equal(t, int32(0), e)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewFirewallFromConfig(t *testing.T) {
|
func TestNewFirewallFromConfig(t *testing.T) {
|
||||||
@@ -631,55 +626,55 @@ func TestNewFirewallFromConfig(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
conf := config.NewC(l)
|
conf := config.NewC(l)
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
|
assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
|
||||||
|
|
||||||
// Test both port and code
|
// Test both port and code
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
|
assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
|
||||||
|
|
||||||
// Test missing host, group, cidr, ca_name and ca_sha
|
// Test missing host, group, cidr, ca_name and ca_sha
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
|
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
|
||||||
|
|
||||||
// Test code/port error
|
// Test code/port error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
|
assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
|
||||||
|
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
|
assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
|
||||||
|
|
||||||
// Test proto error
|
// Test proto error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
|
assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
|
||||||
|
|
||||||
// Test cidr parse error
|
// Test cidr parse error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
||||||
|
|
||||||
// Test local_cidr parse error
|
// Test local_cidr parse error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
||||||
|
|
||||||
// Test both group and groups
|
// Test both group and groups
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
|
assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAddFirewallRulesFromConfig(t *testing.T) {
|
func TestAddFirewallRulesFromConfig(t *testing.T) {
|
||||||
@@ -687,87 +682,87 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
|||||||
// Test adding tcp rule
|
// Test adding tcp rule
|
||||||
conf := config.NewC(l)
|
conf := config.NewC(l)
|
||||||
mf := &mockFirewall{}
|
mf := &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding udp rule
|
// Test adding udp rule
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding icmp rule
|
// Test adding icmp rule
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding any rule
|
// Test adding any rule
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with cidr
|
// Test adding rule with cidr
|
||||||
cidr := netip.MustParsePrefix("10.0.0.0/8")
|
cidr := netip.MustParsePrefix("10.0.0.0/8")
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with local_cidr
|
// Test adding rule with local_cidr
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with ca_sha
|
// Test adding rule with ca_sha
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with ca_name
|
// Test adding rule with ca_name
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
|
||||||
|
|
||||||
// Test single group
|
// Test single group
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test single groups
|
// Test single groups
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test multiple AND groups
|
// Test multiple AND groups
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test Add error
|
// Test Add error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
mf.nextCallReturn = errors.New("test error")
|
mf.nextCallReturn = errors.New("test error")
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
|
||||||
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
|
assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_convertRule(t *testing.T) {
|
func TestFirewall_convertRule(t *testing.T) {
|
||||||
@@ -776,33 +771,33 @@ func TestFirewall_convertRule(t *testing.T) {
|
|||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
|
||||||
// Ensure group array of 1 is converted and a warning is printed
|
// Ensure group array of 1 is converted and a warning is printed
|
||||||
c := map[string]any{
|
c := map[interface{}]interface{}{
|
||||||
"group": []any{"group1"},
|
"group": []interface{}{"group1"},
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err := convertRule(l, c, "test", 1)
|
r, err := convertRule(l, c, "test", 1)
|
||||||
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
|
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "group1", r.Group)
|
assert.Equal(t, "group1", r.Group)
|
||||||
|
|
||||||
// Ensure group array of > 1 is errord
|
// Ensure group array of > 1 is errord
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
c = map[string]any{
|
c = map[interface{}]interface{}{
|
||||||
"group": []any{"group1", "group2"},
|
"group": []interface{}{"group1", "group2"},
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err = convertRule(l, c, "test", 1)
|
r, err = convertRule(l, c, "test", 1)
|
||||||
assert.Empty(t, ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
require.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
|
assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
|
||||||
|
|
||||||
// Make sure a well formed group is alright
|
// Make sure a well formed group is alright
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
c = map[string]any{
|
c = map[interface{}]interface{}{
|
||||||
"group": "group1",
|
"group": "group1",
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err = convertRule(l, c, "test", 1)
|
r, err = convertRule(l, c, "test", 1)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "group1", r.Group)
|
assert.Equal(t, "group1", r.Group)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
35
go.mod
35
go.mod
@@ -1,8 +1,8 @@
|
|||||||
module github.com/slackhq/nebula
|
module github.com/slackhq/nebula
|
||||||
|
|
||||||
go 1.23.0
|
go 1.22.0
|
||||||
|
|
||||||
toolchain go1.24.1
|
toolchain go1.22.2
|
||||||
|
|
||||||
require (
|
require (
|
||||||
dario.cat/mergo v1.0.1
|
dario.cat/mergo v1.0.1
|
||||||
@@ -10,46 +10,49 @@ require (
|
|||||||
github.com/armon/go-radix v1.0.0
|
github.com/armon/go-radix v1.0.0
|
||||||
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
|
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
|
||||||
github.com/flynn/noise v1.1.0
|
github.com/flynn/noise v1.1.0
|
||||||
github.com/gaissmai/bart v0.20.4
|
github.com/gaissmai/bart v0.13.0
|
||||||
github.com/gogo/protobuf v1.3.2
|
github.com/gogo/protobuf v1.3.2
|
||||||
github.com/google/gopacket v1.1.19
|
github.com/google/gopacket v1.1.19
|
||||||
github.com/kardianos/service v1.2.2
|
github.com/kardianos/service v1.2.2
|
||||||
github.com/miekg/dns v1.1.65
|
github.com/miekg/dns v1.1.62
|
||||||
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
|
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
|
||||||
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
|
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
|
||||||
github.com/prometheus/client_golang v1.22.0
|
github.com/prometheus/client_golang v1.20.4
|
||||||
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
|
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
|
||||||
github.com/sirupsen/logrus v1.9.3
|
github.com/sirupsen/logrus v1.9.3
|
||||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||||
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
|
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
|
||||||
github.com/stretchr/testify v1.10.0
|
github.com/stretchr/testify v1.9.0
|
||||||
github.com/vishvananda/netlink v1.3.0
|
github.com/vishvananda/netlink v1.3.0
|
||||||
golang.org/x/crypto v0.37.0
|
golang.org/x/crypto v0.28.0
|
||||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
|
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
|
||||||
golang.org/x/net v0.39.0
|
golang.org/x/net v0.30.0
|
||||||
golang.org/x/sync v0.13.0
|
golang.org/x/sync v0.8.0
|
||||||
golang.org/x/sys v0.32.0
|
golang.org/x/sys v0.26.0
|
||||||
golang.org/x/term v0.31.0
|
golang.org/x/term v0.25.0
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
|
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||||
google.golang.org/protobuf v1.36.6
|
google.golang.org/protobuf v1.35.1
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v2 v2.4.0
|
||||||
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
|
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/beorn7/perks v1.0.1 // indirect
|
github.com/beorn7/perks v1.0.1 // indirect
|
||||||
|
github.com/bits-and-blooms/bitset v1.14.3 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
github.com/google/btree v1.1.2 // indirect
|
github.com/google/btree v1.1.2 // indirect
|
||||||
|
github.com/klauspost/compress v1.17.9 // indirect
|
||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/prometheus/client_model v0.6.1 // indirect
|
github.com/prometheus/client_model v0.6.1 // indirect
|
||||||
github.com/prometheus/common v0.62.0 // indirect
|
github.com/prometheus/common v0.55.0 // indirect
|
||||||
github.com/prometheus/procfs v0.15.1 // indirect
|
github.com/prometheus/procfs v0.15.1 // indirect
|
||||||
github.com/vishvananda/netns v0.0.4 // indirect
|
github.com/vishvananda/netns v0.0.4 // indirect
|
||||||
golang.org/x/mod v0.23.0 // indirect
|
golang.org/x/mod v0.18.0 // indirect
|
||||||
golang.org/x/time v0.5.0 // indirect
|
golang.org/x/time v0.5.0 // indirect
|
||||||
golang.org/x/tools v0.30.0 // indirect
|
golang.org/x/tools v0.22.0 // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
64
go.sum
64
go.sum
@@ -14,6 +14,8 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24
|
|||||||
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
|
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
|
||||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||||
|
github.com/bits-and-blooms/bitset v1.14.3 h1:Gd2c8lSNf9pKXom5JtD7AaKO8o7fGQ2LtFj1436qilA=
|
||||||
|
github.com/bits-and-blooms/bitset v1.14.3/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
|
||||||
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
@@ -24,8 +26,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
|||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
|
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
|
||||||
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
|
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
|
||||||
github.com/gaissmai/bart v0.20.4 h1:Ik47r1fy3jRVU+1eYzKSW3ho2UgBVTVnUS8O993584U=
|
github.com/gaissmai/bart v0.13.0 h1:pItEhXDVVebUa+i978FfQ7ye8xZc1FrMgs8nJPPWAgA=
|
||||||
github.com/gaissmai/bart v0.20.4/go.mod h1:cEed+ge8dalcbpi8wtS9x9m2hn/fNJH5suhdGQOHnYk=
|
github.com/gaissmai/bart v0.13.0/go.mod h1:qSes2fnJ8hB410BW0ymHUN/eQkuGpTYyJcN8sKMYpJU=
|
||||||
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||||
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||||
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
|
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
|
||||||
@@ -53,8 +55,8 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
|
|||||||
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
|
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
|
||||||
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
|
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
|
||||||
@@ -68,8 +70,8 @@ github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX
|
|||||||
github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
|
||||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
|
||||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||||
github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||||
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
|
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
|
||||||
@@ -83,8 +85,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
|||||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||||
github.com/miekg/dns v1.1.65 h1:0+tIPHzUW0GCge7IiK3guGP57VAw7hoPDfApjkMD1Fc=
|
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
|
||||||
github.com/miekg/dns v1.1.65/go.mod h1:Dzw9769uoKVaLuODMDZz9M6ynFU6Em65csPuoi8G0ck=
|
github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ=
|
||||||
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
|
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
|
||||||
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
|
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
@@ -106,8 +108,8 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
|
|||||||
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
|
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
|
||||||
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
|
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
|
||||||
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
|
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
|
||||||
github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q=
|
github.com/prometheus/client_golang v1.20.4 h1:Tgh3Yr67PaOv/uTqloMsCEdeuFTatm5zIq5+qNN23vI=
|
||||||
github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0=
|
github.com/prometheus/client_golang v1.20.4/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE=
|
||||||
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
|
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
|
||||||
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||||
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||||
@@ -116,8 +118,8 @@ github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQy
|
|||||||
github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
|
github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
|
||||||
github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
|
github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
|
||||||
github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
|
github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
|
||||||
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
|
github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc=
|
||||||
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8=
|
||||||
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
|
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
|
||||||
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
|
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
|
||||||
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
|
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
|
||||||
@@ -143,8 +145,8 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
|
|||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
|
github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
|
||||||
github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
|
github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
|
||||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||||
@@ -156,16 +158,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
|||||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||||
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
|
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
|
||||||
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
|
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
|
||||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
|
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
|
||||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
||||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
|
golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0=
|
||||||
golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
|
golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
@@ -176,8 +178,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
|
|||||||
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
|
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
|
||||||
golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
|
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
|
||||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
@@ -185,8 +187,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
|||||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
|
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
|
||||||
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
@@ -204,11 +206,11 @@ golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
|
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
|
||||||
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o=
|
golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24=
|
||||||
golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw=
|
golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
@@ -219,8 +221,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
|
|||||||
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
||||||
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||||
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||||
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
|
golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA=
|
||||||
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
|
golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c=
|
||||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
@@ -239,8 +241,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
|
|||||||
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
|
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
|
||||||
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||||
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
|
google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA=
|
||||||
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
|
google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||||
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
@@ -251,6 +253,8 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
|||||||
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
|
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||||
|
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
|||||||
113
handshake_ix.go
113
handshake_ix.go
@@ -25,7 +25,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|||||||
|
|
||||||
// If we're connecting to a v6 address we must use a v2 cert
|
// If we're connecting to a v6 address we must use a v2 cert
|
||||||
cs := f.pki.getCertState()
|
cs := f.pki.getCertState()
|
||||||
v := cs.initiatingVersion
|
v := cs.defaultVersion
|
||||||
for _, a := range hh.hostinfo.vpnAddrs {
|
for _, a := range hh.hostinfo.vpnAddrs {
|
||||||
if a.Is6() {
|
if a.Is6() {
|
||||||
v = cert.Version2
|
v = cert.Version2
|
||||||
@@ -71,8 +71,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|||||||
|
|
||||||
hsBytes, err := hs.Marshal()
|
hsBytes, err := hs.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).WithField("certVersion", v).
|
||||||
WithField("certVersion", v).
|
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -101,7 +100,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
if crt == nil {
|
if crt == nil {
|
||||||
f.l.WithField("udpAddr", addr).
|
f.l.WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||||
WithField("certVersion", cs.initiatingVersion).
|
WithField("certVersion", cs.defaultVersion).
|
||||||
Error("Unable to handshake with host because no certificate is available")
|
Error("Unable to handshake with host because no certificate is available")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -133,28 +132,13 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("udpAddr", addr).
|
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
|
||||||
Info("Handshake did not contain a certificate")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
|
|
||||||
if err != nil {
|
|
||||||
fp, err := rc.Fingerprint()
|
|
||||||
if err != nil {
|
|
||||||
fp = "<error generating certificate fingerprint>"
|
|
||||||
}
|
|
||||||
|
|
||||||
e := f.l.WithError(err).WithField("udpAddr", addr).
|
e := f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
|
||||||
WithField("certVpnNetworks", rc.Networks()).
|
|
||||||
WithField("certFingerprint", fp)
|
|
||||||
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level > logrus.DebugLevel {
|
||||||
e = e.WithField("cert", rc)
|
e = e.WithField("cert", remoteCert)
|
||||||
}
|
}
|
||||||
|
|
||||||
e.Info("Invalid certificate from host")
|
e.Info("Invalid certificate from host")
|
||||||
@@ -176,26 +160,29 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||||
f.l.WithError(err).WithField("udpAddr", addr).
|
e := f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("cert", remoteCert).
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
|
||||||
Info("No networks in certificate")
|
if f.l.Level > logrus.DebugLevel {
|
||||||
|
e = e.WithField("cert", remoteCert)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.Info("Invalid vpn ip from host")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var vpnAddrs []netip.Addr
|
var vpnAddrs []netip.Addr
|
||||||
var filteredNetworks []netip.Prefix
|
var filteredNetworks []netip.Prefix
|
||||||
certName := remoteCert.Certificate.Name()
|
certName := remoteCert.Certificate.Name()
|
||||||
certVersion := remoteCert.Certificate.Version()
|
|
||||||
fingerprint := remoteCert.Fingerprint
|
fingerprint := remoteCert.Fingerprint
|
||||||
issuer := remoteCert.Certificate.Issuer()
|
issuer := remoteCert.Certificate.Issuer()
|
||||||
|
|
||||||
for _, network := range remoteCert.Certificate.Networks() {
|
for _, network := range remoteCert.Certificate.Networks() {
|
||||||
vpnAddr := network.Addr()
|
vpnAddr := network.Addr()
|
||||||
if f.myVpnAddrsTable.Contains(vpnAddr) {
|
_, found := f.myVpnAddrsTable.Lookup(vpnAddr)
|
||||||
|
if found {
|
||||||
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
||||||
@@ -203,7 +190,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
}
|
}
|
||||||
|
|
||||||
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
||||||
if !f.myVpnNetworksTable.Contains(vpnAddr) {
|
if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -214,7 +201,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
if len(vpnAddrs) == 0 {
|
if len(vpnAddrs) == 0 {
|
||||||
f.l.WithError(err).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
||||||
@@ -234,7 +220,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
|
||||||
@@ -257,7 +242,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
|
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
@@ -269,7 +253,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
if hs.Details.Cert == nil {
|
if hs.Details.Cert == nil {
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
@@ -287,7 +270,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
||||||
@@ -299,7 +281,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
||||||
@@ -307,7 +288,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
} else if dKey == nil || eKey == nil {
|
} else if dKey == nil || eKey == nil {
|
||||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
|
||||||
@@ -375,7 +355,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
|
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("oldHandshakeTime", existing.lastHandshakeTime).
|
WithField("oldHandshakeTime", existing.lastHandshakeTime).
|
||||||
WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
|
WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
@@ -391,7 +370,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
|
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
@@ -404,7 +382,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
// And we forget to update it here
|
// And we forget to update it here
|
||||||
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
@@ -421,7 +398,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
@@ -430,7 +406,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
} else {
|
} else {
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
@@ -449,7 +424,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
@@ -513,48 +487,35 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("udpAddr", addr).
|
e := f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
|
||||||
Info("Handshake did not contain a certificate")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
|
if f.l.Level > logrus.DebugLevel {
|
||||||
if err != nil {
|
e = e.WithField("cert", remoteCert)
|
||||||
fp, err := rc.Fingerprint()
|
|
||||||
if err != nil {
|
|
||||||
fp = "<error generating certificate fingerprint>"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
e := f.l.WithError(err).WithField("udpAddr", addr).
|
e.Error("Invalid certificate from host")
|
||||||
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
|
||||||
WithField("certFingerprint", fp).
|
|
||||||
WithField("certVpnNetworks", rc.Networks())
|
|
||||||
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
|
||||||
e = e.WithField("cert", rc)
|
|
||||||
}
|
|
||||||
|
|
||||||
e.Info("Invalid certificate from host")
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||||
f.l.WithError(err).WithField("udpAddr", addr).
|
e := f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
|
||||||
WithField("cert", remoteCert).
|
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
if f.l.Level > logrus.DebugLevel {
|
||||||
Info("No networks in certificate")
|
e = e.WithField("cert", remoteCert)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.Info("Empty networks from host")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
vpnNetworks := remoteCert.Certificate.Networks()
|
vpnNetworks := remoteCert.Certificate.Networks()
|
||||||
certName := remoteCert.Certificate.Name()
|
certName := remoteCert.Certificate.Name()
|
||||||
certVersion := remoteCert.Certificate.Version()
|
|
||||||
fingerprint := remoteCert.Fingerprint
|
fingerprint := remoteCert.Fingerprint
|
||||||
issuer := remoteCert.Certificate.Issuer()
|
issuer := remoteCert.Certificate.Issuer()
|
||||||
|
|
||||||
@@ -578,7 +539,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
for _, network := range vpnNetworks {
|
for _, network := range vpnNetworks {
|
||||||
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
||||||
vpnAddr := network.Addr()
|
vpnAddr := network.Addr()
|
||||||
if !f.myVpnNetworksTable.Contains(vpnAddr) {
|
if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -589,7 +550,6 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
if len(vpnAddrs) == 0 {
|
if len(vpnAddrs) == 0 {
|
||||||
f.l.WithError(err).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
||||||
@@ -599,9 +559,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
// Ensure the right host responded
|
// Ensure the right host responded
|
||||||
if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
|
if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
|
||||||
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
||||||
WithField("udpAddr", addr).
|
WithField("udpAddr", addr).WithField("certName", certName).
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
Info("Incorrect host responded to handshake")
|
Info("Incorrect host responded to handshake")
|
||||||
|
|
||||||
@@ -637,7 +595,6 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
duration := time.Since(hh.startTime).Nanoseconds()
|
duration := time.Since(hh.startTime).Nanoseconds()
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
|
|||||||
@@ -257,7 +257,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
WithField("initiatorIndex", hostinfo.localIndexId).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
Info("Handshake message sent")
|
Info("Handshake message sent")
|
||||||
} else if hm.l.Level >= logrus.DebugLevel {
|
} else if hm.l.IsLevelEnabled(logrus.DebugLevel) {
|
||||||
hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
|
hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
|
||||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
WithField("initiatorIndex", hostinfo.localIndexId).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
@@ -274,7 +274,8 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Don't relay through the host I'm trying to connect to
|
// Don't relay through the host I'm trying to connect to
|
||||||
if hm.f.myVpnAddrsTable.Contains(relay) {
|
_, found := hm.f.myVpnAddrsTable.Lookup(relay)
|
||||||
|
if found {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -24,10 +24,10 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
|||||||
lh := newTestLighthouse()
|
lh := newTestLighthouse()
|
||||||
|
|
||||||
cs := &CertState{
|
cs := &CertState{
|
||||||
initiatingVersion: cert.Version1,
|
defaultVersion: cert.Version1,
|
||||||
privateKey: []byte{},
|
privateKey: []byte{},
|
||||||
v1Cert: &dummyCert{version: cert.Version1},
|
v1Cert: &dummyCert{version: cert.Version1},
|
||||||
v1HandshakeBytes: []byte{},
|
v1HandshakeBytes: []byte{},
|
||||||
}
|
}
|
||||||
|
|
||||||
blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
|
blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
|
||||||
@@ -44,7 +44,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
|||||||
i.remotes = NewRemoteList([]netip.Addr{}, nil)
|
i.remotes = NewRemoteList([]netip.Addr{}, nil)
|
||||||
|
|
||||||
// Adding something to pending should not affect the main hostmap
|
// Adding something to pending should not affect the main hostmap
|
||||||
assert.Empty(t, mainHM.Hosts)
|
assert.Len(t, mainHM.Hosts, 0)
|
||||||
|
|
||||||
// Confirm they are in the pending index list
|
// Confirm they are in the pending index list
|
||||||
assert.Contains(t, blah.vpnIps, ip)
|
assert.Contains(t, blah.vpnIps, ip)
|
||||||
@@ -98,5 +98,5 @@ func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (mw *mockEncWriter) GetCertState() *CertState {
|
func (mw *mockEncWriter) GetCertState() *CertState {
|
||||||
return &CertState{initiatingVersion: cert.Version2}
|
return &CertState{defaultVersion: cert.Version2}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import (
|
|||||||
// |-----------------------------------------------------------------------|
|
// |-----------------------------------------------------------------------|
|
||||||
// | payload... |
|
// | payload... |
|
||||||
|
|
||||||
type m = map[string]any
|
type m map[string]interface{}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
Version uint8 = 1
|
Version uint8 = 1
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type headerTest struct {
|
type headerTest struct {
|
||||||
@@ -112,7 +111,7 @@ func TestHeader_String(t *testing.T) {
|
|||||||
|
|
||||||
func TestHeader_MarshalJSON(t *testing.T) {
|
func TestHeader_MarshalJSON(t *testing.T) {
|
||||||
b, err := (&H{100, Test, TestRequest, 99, 98, 97}).MarshalJSON()
|
b, err := (&H{100, Test, TestRequest, 99, 98, 97}).MarshalJSON()
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(
|
assert.Equal(
|
||||||
t,
|
t,
|
||||||
"{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}",
|
"{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}",
|
||||||
|
|||||||
@@ -223,7 +223,7 @@ type HostInfo struct {
|
|||||||
recvError atomic.Uint32
|
recvError atomic.Uint32
|
||||||
|
|
||||||
// networks are both all vpn and unsafe networks assigned to this host
|
// networks are both all vpn and unsafe networks assigned to this host
|
||||||
networks *bart.Lite
|
networks *bart.Table[struct{}]
|
||||||
relayState RelayState
|
relayState RelayState
|
||||||
|
|
||||||
// HandshakePacket records the packets used to create this hostinfo
|
// HandshakePacket records the packets used to create this hostinfo
|
||||||
@@ -732,13 +732,13 @@ func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
i.networks = new(bart.Lite)
|
i.networks = new(bart.Table[struct{}])
|
||||||
for _, network := range networks {
|
for _, network := range networks {
|
||||||
i.networks.Insert(network)
|
i.networks.Insert(network, struct{}{})
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, network := range unsafeNetworks {
|
for _, network := range unsafeNetworks {
|
||||||
i.networks.Insert(network)
|
i.networks.Insert(network, struct{}{})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -210,8 +210,8 @@ func TestHostMap_reload(t *testing.T) {
|
|||||||
assert.Empty(t, hm.GetPreferredRanges())
|
assert.Empty(t, hm.GetPreferredRanges())
|
||||||
|
|
||||||
c.ReloadConfigString("preferred_ranges: [1.1.1.0/24, 10.1.1.0/24]")
|
c.ReloadConfigString("preferred_ranges: [1.1.1.0/24, 10.1.1.0/24]")
|
||||||
assert.Equal(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges()))
|
assert.EqualValues(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges()))
|
||||||
|
|
||||||
c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
|
c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
|
||||||
assert.Equal(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
|
assert.EqualValues(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
|
||||||
}
|
}
|
||||||
|
|||||||
102
inside.go
102
inside.go
@@ -8,7 +8,6 @@ import (
|
|||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/iputil"
|
"github.com/slackhq/nebula/iputil"
|
||||||
"github.com/slackhq/nebula/noiseutil"
|
"github.com/slackhq/nebula/noiseutil"
|
||||||
"github.com/slackhq/nebula/routing"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
||||||
@@ -22,12 +21,14 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
|
|
||||||
// Ignore local broadcast packets
|
// Ignore local broadcast packets
|
||||||
if f.dropLocalBroadcast {
|
if f.dropLocalBroadcast {
|
||||||
if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
|
_, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr)
|
||||||
|
if found {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) {
|
_, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr)
|
||||||
|
if found {
|
||||||
// Immediately forward packets from self to self.
|
// Immediately forward packets from self to self.
|
||||||
// This should only happen on Darwin-based and FreeBSD hosts, which
|
// This should only happen on Darwin-based and FreeBSD hosts, which
|
||||||
// routes packets from the Nebula addr to the Nebula addr through the Nebula
|
// routes packets from the Nebula addr to the Nebula addr through the Nebula
|
||||||
@@ -48,7 +49,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
|
hostinfo, ready := f.getOrHandshake(fwPacket.RemoteAddr, func(hh *HandshakeHostInfo) {
|
||||||
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
|
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -120,93 +121,22 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
|
|||||||
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
|
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
|
|
||||||
func (f *Interface) Handshake(vpnAddr netip.Addr) {
|
func (f *Interface) Handshake(vpnAddr netip.Addr) {
|
||||||
f.getOrHandshakeNoRouting(vpnAddr, nil)
|
f.getOrHandshake(vpnAddr, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
|
// getOrHandshake returns nil if the vpnAddr is not routable.
|
||||||
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
|
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
|
||||||
func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
|
func (f *Interface) getOrHandshake(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
|
||||||
if f.myVpnNetworksTable.Contains(vpnAddr) {
|
_, found := f.myVpnNetworksTable.Lookup(vpnAddr)
|
||||||
return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
|
if !found {
|
||||||
}
|
vpnAddr = f.inside.RouteFor(vpnAddr)
|
||||||
|
if !vpnAddr.IsValid() {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary.
|
|
||||||
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel.
|
|
||||||
func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
|
|
||||||
|
|
||||||
destinationAddr := fwPacket.RemoteAddr
|
|
||||||
|
|
||||||
hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback)
|
|
||||||
|
|
||||||
// Host is inside the mesh, no routing required
|
|
||||||
if hostinfo != nil {
|
|
||||||
return hostinfo, ready
|
|
||||||
}
|
|
||||||
|
|
||||||
gateways := f.inside.RoutesFor(destinationAddr)
|
|
||||||
|
|
||||||
switch len(gateways) {
|
|
||||||
case 0:
|
|
||||||
return nil, false
|
|
||||||
case 1:
|
|
||||||
// Single gateway route
|
|
||||||
return f.handshakeManager.GetOrHandshake(gateways[0].Addr(), cacheCallback)
|
|
||||||
default:
|
|
||||||
// Multi gateway route, perform ECMP categorization
|
|
||||||
gatewayAddr, balancingOk := routing.BalancePacket(fwPacket, gateways)
|
|
||||||
|
|
||||||
if !balancingOk {
|
|
||||||
// This happens if the gateway buckets were not calculated, this _should_ never happen
|
|
||||||
f.l.Error("Gateway buckets not calculated, fallback from ECMP to random routing. Please report this bug.")
|
|
||||||
}
|
|
||||||
|
|
||||||
var handshakeInfoForChosenGateway *HandshakeHostInfo
|
|
||||||
var hhReceiver = func(hh *HandshakeHostInfo) {
|
|
||||||
handshakeInfoForChosenGateway = hh
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store the handshakeHostInfo for later.
|
|
||||||
// If this node is not reachable we will attempt other nodes, if none are reachable we will
|
|
||||||
// cache the packet for this gateway.
|
|
||||||
if hostinfo, ready = f.handshakeManager.GetOrHandshake(gatewayAddr, hhReceiver); ready {
|
|
||||||
return hostinfo, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// It appears the selected gateway cannot be reached, find another gateway to fallback on.
|
|
||||||
// The current implementation breaks ECMP but that seems better than no connectivity.
|
|
||||||
// If ECMP is also required when a gateway is down then connectivity status
|
|
||||||
// for each gateway needs to be kept and the weights recalculated when they go up or down.
|
|
||||||
// This would also need to interact with unsafe_route updates through reloading the config or
|
|
||||||
// use of the use_system_route_table option
|
|
||||||
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
f.l.WithField("destination", destinationAddr).
|
|
||||||
WithField("originalGateway", gatewayAddr).
|
|
||||||
Debugln("Calculated gateway for ECMP not available, attempting other gateways")
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range gateways {
|
|
||||||
// Skip the gateway that failed previously
|
|
||||||
if gateways[i].Addr() == gatewayAddr {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// We do not need the HandshakeHostInfo since we cache the packet in the originally chosen gateway
|
|
||||||
if hostinfo, ready = f.handshakeManager.GetOrHandshake(gateways[i].Addr(), nil); ready {
|
|
||||||
return hostinfo, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// No gateways reachable, cache the packet in the originally chosen gateway
|
|
||||||
cacheCallback(handshakeInfoForChosenGateway)
|
|
||||||
return hostinfo, false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
|
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
|
||||||
@@ -233,7 +163,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
|
|||||||
|
|
||||||
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
|
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
|
||||||
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
|
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
|
||||||
hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) {
|
hostInfo, ready := f.getOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
|
||||||
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
|
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
14
interface.go
14
interface.go
@@ -61,11 +61,11 @@ type Interface struct {
|
|||||||
serveDns bool
|
serveDns bool
|
||||||
createTime time.Time
|
createTime time.Time
|
||||||
lightHouse *LightHouse
|
lightHouse *LightHouse
|
||||||
myBroadcastAddrsTable *bart.Lite
|
myBroadcastAddrsTable *bart.Table[struct{}]
|
||||||
myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate
|
myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate
|
||||||
myVpnAddrsTable *bart.Lite
|
myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate
|
||||||
myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate
|
myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate
|
||||||
myVpnNetworksTable *bart.Lite
|
myVpnNetworksTable *bart.Table[struct{}] // A table of networks assigned to us via our certificate
|
||||||
dropLocalBroadcast bool
|
dropLocalBroadcast bool
|
||||||
dropMulticast bool
|
dropMulticast bool
|
||||||
routines int
|
routines int
|
||||||
@@ -410,7 +410,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
|
|||||||
udpStats := udp.NewUDPStatsEmitter(f.writers)
|
udpStats := udp.NewUDPStatsEmitter(f.writers)
|
||||||
|
|
||||||
certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
|
certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
|
||||||
certInitiatingVersion := metrics.GetOrRegisterGauge("certificate.initiating_version", nil)
|
certDefaultVersion := metrics.GetOrRegisterGauge("certificate.default_version", nil)
|
||||||
certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil)
|
certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -425,7 +425,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
|
|||||||
certState := f.pki.getCertState()
|
certState := f.pki.getCertState()
|
||||||
defaultCrt := certState.GetDefaultCertificate()
|
defaultCrt := certState.GetDefaultCertificate()
|
||||||
certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second))
|
certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second))
|
||||||
certInitiatingVersion.Update(int64(defaultCrt.Version()))
|
certDefaultVersion.Update(int64(defaultCrt.Version()))
|
||||||
|
|
||||||
// Report the max certificate version we are capable of using
|
// Report the max certificate version we are capable of using
|
||||||
if certState.v2Cert != nil {
|
if certState.v2Cert != nil {
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ type LightHouse struct {
|
|||||||
amLighthouse bool
|
amLighthouse bool
|
||||||
|
|
||||||
myVpnNetworks []netip.Prefix
|
myVpnNetworks []netip.Prefix
|
||||||
myVpnNetworksTable *bart.Lite
|
myVpnNetworksTable *bart.Table[struct{}]
|
||||||
punchConn udp.Conn
|
punchConn udp.Conn
|
||||||
punchy *Punchy
|
punchy *Punchy
|
||||||
|
|
||||||
@@ -201,7 +201,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
|||||||
|
|
||||||
//TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used
|
//TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used
|
||||||
addr := addrs[0].Unmap()
|
addr := addrs[0].Unmap()
|
||||||
if lh.myVpnNetworksTable.Contains(addr) {
|
_, found := lh.myVpnNetworksTable.Lookup(addr)
|
||||||
|
if found {
|
||||||
lh.l.WithField("addr", rawAddr).WithField("entry", i+1).
|
lh.l.WithField("addr", rawAddr).WithField("entry", i+1).
|
||||||
Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range")
|
Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range")
|
||||||
continue
|
continue
|
||||||
@@ -358,7 +359,8 @@ func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{
|
|||||||
return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
|
return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !lh.myVpnNetworksTable.Contains(addr) {
|
_, found := lh.myVpnNetworksTable.Lookup(addr)
|
||||||
|
if !found {
|
||||||
return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
|
return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
|
||||||
}
|
}
|
||||||
lhMap[addr] = struct{}{}
|
lhMap[addr] = struct{}{}
|
||||||
@@ -420,7 +422,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
shm := c.GetMap("static_host_map", map[string]any{})
|
shm := c.GetMap("static_host_map", map[interface{}]interface{}{})
|
||||||
i := 0
|
i := 0
|
||||||
|
|
||||||
for k, v := range shm {
|
for k, v := range shm {
|
||||||
@@ -429,13 +431,14 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
|
|||||||
return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err)
|
return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !lh.myVpnNetworksTable.Contains(vpnAddr) {
|
_, found := lh.myVpnNetworksTable.Lookup(vpnAddr)
|
||||||
|
if !found {
|
||||||
return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil)
|
return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
vals, ok := v.([]any)
|
vals, ok := v.([]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
vals = []any{v}
|
vals = []interface{}{v}
|
||||||
}
|
}
|
||||||
remoteAddrs := []string{}
|
remoteAddrs := []string{}
|
||||||
for _, v := range vals {
|
for _, v := range vals {
|
||||||
@@ -650,7 +653,8 @@ func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if lh.myVpnNetworksTable.Contains(to) {
|
_, found := lh.myVpnNetworksTable.Lookup(to)
|
||||||
|
if found {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -670,7 +674,8 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bo
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if lh.myVpnNetworksTable.Contains(udpAddr.Addr()) {
|
_, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr())
|
||||||
|
if found {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -690,7 +695,8 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if lh.myVpnNetworksTable.Contains(udpAddr.Addr()) {
|
_, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr())
|
||||||
|
if found {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -757,7 +763,7 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
|
|||||||
if hi != nil {
|
if hi != nil {
|
||||||
v = hi.ConnectionState.myCert.Version()
|
v = hi.ConnectionState.myCert.Version()
|
||||||
} else {
|
} else {
|
||||||
v = lh.ifce.GetCertState().initiatingVersion
|
v = lh.ifce.GetCertState().defaultVersion
|
||||||
}
|
}
|
||||||
|
|
||||||
if v == cert.Version1 {
|
if v == cert.Version1 {
|
||||||
@@ -850,7 +856,8 @@ func (lh *LightHouse) SendUpdate() {
|
|||||||
|
|
||||||
lal := lh.GetLocalAllowList()
|
lal := lh.GetLocalAllowList()
|
||||||
for _, e := range localAddrs(lh.l, lal) {
|
for _, e := range localAddrs(lh.l, lal) {
|
||||||
if lh.myVpnNetworksTable.Contains(e) {
|
_, found := lh.myVpnNetworksTable.Lookup(e)
|
||||||
|
if found {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -876,7 +883,7 @@ func (lh *LightHouse) SendUpdate() {
|
|||||||
if hi != nil {
|
if hi != nil {
|
||||||
v = hi.ConnectionState.myCert.Version()
|
v = hi.ConnectionState.myCert.Version()
|
||||||
} else {
|
} else {
|
||||||
v = lh.ifce.GetCertState().initiatingVersion
|
v = lh.ifce.GetCertState().defaultVersion
|
||||||
}
|
}
|
||||||
if v == cert.Version1 {
|
if v == cert.Version1 {
|
||||||
if v1Update == nil {
|
if v1Update == nil {
|
||||||
@@ -1107,7 +1114,7 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
|
|||||||
targetHI := lhh.lh.ifce.GetHostInfo(punchNotifDest)
|
targetHI := lhh.lh.ifce.GetHostInfo(punchNotifDest)
|
||||||
var useVersion cert.Version
|
var useVersion cert.Version
|
||||||
if targetHI == nil {
|
if targetHI == nil {
|
||||||
useVersion = lhh.lh.ifce.GetCertState().initiatingVersion
|
useVersion = lhh.lh.ifce.GetCertState().defaultVersion
|
||||||
} else {
|
} else {
|
||||||
crt := targetHI.GetCert().Certificate
|
crt := targetHI.GetCert().Certificate
|
||||||
useVersion = crt.Version()
|
useVersion = crt.Version()
|
||||||
|
|||||||
@@ -13,8 +13,7 @@ import (
|
|||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"gopkg.in/yaml.v2"
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestOldIPv4Only(t *testing.T) {
|
func TestOldIPv4Only(t *testing.T) {
|
||||||
@@ -22,7 +21,7 @@ func TestOldIPv4Only(t *testing.T) {
|
|||||||
b := []byte{8, 129, 130, 132, 80, 16, 10}
|
b := []byte{8, 129, 130, 132, 80, 16, 10}
|
||||||
var m V4AddrPort
|
var m V4AddrPort
|
||||||
err := m.Unmarshal(b)
|
err := m.Unmarshal(b)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
ip := netip.MustParseAddr("10.1.1.1")
|
ip := netip.MustParseAddr("10.1.1.1")
|
||||||
bp := ip.As4()
|
bp := ip.As4()
|
||||||
assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr())
|
assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr())
|
||||||
@@ -31,8 +30,8 @@ func TestOldIPv4Only(t *testing.T) {
|
|||||||
func Test_lhStaticMapping(t *testing.T) {
|
func Test_lhStaticMapping(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
|
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
|
||||||
nt := new(bart.Lite)
|
nt := new(bart.Table[struct{}])
|
||||||
nt.Insert(myVpnNet)
|
nt.Insert(myVpnNet, struct{}{})
|
||||||
cs := &CertState{
|
cs := &CertState{
|
||||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||||
myVpnNetworksTable: nt,
|
myVpnNetworksTable: nt,
|
||||||
@@ -40,24 +39,24 @@ func Test_lhStaticMapping(t *testing.T) {
|
|||||||
lh1 := "10.128.0.2"
|
lh1 := "10.128.0.2"
|
||||||
|
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1}}
|
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
|
||||||
c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}}
|
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
|
||||||
_, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
_, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
lh2 := "10.128.0.3"
|
lh2 := "10.128.0.3"
|
||||||
c = config.NewC(l)
|
c = config.NewC(l)
|
||||||
c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1, lh2}}
|
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
|
||||||
c.Settings["static_host_map"] = map[string]any{lh1: []any{"100.1.1.1:4242"}}
|
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
|
||||||
_, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
_, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||||
require.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
|
assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReloadLighthouseInterval(t *testing.T) {
|
func TestReloadLighthouseInterval(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
|
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
|
||||||
nt := new(bart.Lite)
|
nt := new(bart.Table[struct{}])
|
||||||
nt.Insert(myVpnNet)
|
nt.Insert(myVpnNet, struct{}{})
|
||||||
cs := &CertState{
|
cs := &CertState{
|
||||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||||
myVpnNetworksTable: nt,
|
myVpnNetworksTable: nt,
|
||||||
@@ -65,34 +64,34 @@ func TestReloadLighthouseInterval(t *testing.T) {
|
|||||||
lh1 := "10.128.0.2"
|
lh1 := "10.128.0.2"
|
||||||
|
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
c.Settings["lighthouse"] = map[string]any{
|
c.Settings["lighthouse"] = map[interface{}]interface{}{
|
||||||
"hosts": []any{lh1},
|
"hosts": []interface{}{lh1},
|
||||||
"interval": "1s",
|
"interval": "1s",
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}}
|
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
|
||||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
lh.ifce = &mockEncWriter{}
|
lh.ifce = &mockEncWriter{}
|
||||||
|
|
||||||
// The first one routine is kicked off by main.go currently, lets make sure that one dies
|
// The first one routine is kicked off by main.go currently, lets make sure that one dies
|
||||||
require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5"))
|
assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5"))
|
||||||
assert.Equal(t, int64(5), lh.interval.Load())
|
assert.Equal(t, int64(5), lh.interval.Load())
|
||||||
|
|
||||||
// Subsequent calls are killed off by the LightHouse.Reload function
|
// Subsequent calls are killed off by the LightHouse.Reload function
|
||||||
require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10"))
|
assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10"))
|
||||||
assert.Equal(t, int64(10), lh.interval.Load())
|
assert.Equal(t, int64(10), lh.interval.Load())
|
||||||
|
|
||||||
// If this completes then nothing is stealing our reload routine
|
// If this completes then nothing is stealing our reload routine
|
||||||
require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11"))
|
assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11"))
|
||||||
assert.Equal(t, int64(11), lh.interval.Load())
|
assert.Equal(t, int64(11), lh.interval.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
|
myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
|
||||||
nt := new(bart.Lite)
|
nt := new(bart.Table[struct{}])
|
||||||
nt.Insert(myVpnNet)
|
nt.Insert(myVpnNet, struct{}{})
|
||||||
cs := &CertState{
|
cs := &CertState{
|
||||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||||
myVpnNetworksTable: nt,
|
myVpnNetworksTable: nt,
|
||||||
@@ -100,7 +99,9 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
|||||||
|
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||||
require.NoError(b, err)
|
if !assert.NoError(b, err) {
|
||||||
|
b.Fatal()
|
||||||
|
}
|
||||||
|
|
||||||
hAddr := netip.MustParseAddrPort("4.5.6.7:12345")
|
hAddr := netip.MustParseAddrPort("4.5.6.7:12345")
|
||||||
hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
|
hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
|
||||||
@@ -144,7 +145,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
p, err := req.Marshal()
|
p, err := req.Marshal()
|
||||||
require.NoError(b, err)
|
assert.NoError(b, err)
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
lhh.HandleRequest(rAddr, hi, p, mw)
|
lhh.HandleRequest(rAddr, hi, p, mw)
|
||||||
}
|
}
|
||||||
@@ -159,7 +160,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
p, err := req.Marshal()
|
p, err := req.Marshal()
|
||||||
require.NoError(b, err)
|
assert.NoError(b, err)
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
lhh.HandleRequest(rAddr, hi, p, mw)
|
lhh.HandleRequest(rAddr, hi, p, mw)
|
||||||
@@ -192,19 +193,19 @@ func TestLighthouse_Memory(t *testing.T) {
|
|||||||
theirVpnIp := netip.MustParseAddr("10.128.0.3")
|
theirVpnIp := netip.MustParseAddr("10.128.0.3")
|
||||||
|
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true}
|
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
|
||||||
c.Settings["listen"] = map[string]any{"port": 4242}
|
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
|
||||||
|
|
||||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
|
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
|
||||||
nt := new(bart.Lite)
|
nt := new(bart.Table[struct{}])
|
||||||
nt.Insert(myVpnNet)
|
nt.Insert(myVpnNet, struct{}{})
|
||||||
cs := &CertState{
|
cs := &CertState{
|
||||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||||
myVpnNetworksTable: nt,
|
myVpnNetworksTable: nt,
|
||||||
}
|
}
|
||||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||||
lh.ifce = &mockEncWriter{}
|
lh.ifce = &mockEncWriter{}
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
lhh := lh.NewRequestHandler()
|
lhh := lh.NewRequestHandler()
|
||||||
|
|
||||||
// Test that my first update responds with just that
|
// Test that my first update responds with just that
|
||||||
@@ -277,31 +278,31 @@ func TestLighthouse_Memory(t *testing.T) {
|
|||||||
func TestLighthouse_reload(t *testing.T) {
|
func TestLighthouse_reload(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true}
|
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
|
||||||
c.Settings["listen"] = map[string]any{"port": 4242}
|
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
|
||||||
|
|
||||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
|
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
|
||||||
nt := new(bart.Lite)
|
nt := new(bart.Table[struct{}])
|
||||||
nt.Insert(myVpnNet)
|
nt.Insert(myVpnNet, struct{}{})
|
||||||
cs := &CertState{
|
cs := &CertState{
|
||||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||||
myVpnNetworksTable: nt,
|
myVpnNetworksTable: nt,
|
||||||
}
|
}
|
||||||
|
|
||||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
nc := map[string]any{
|
nc := map[interface{}]interface{}{
|
||||||
"static_host_map": map[string]any{
|
"static_host_map": map[interface{}]interface{}{
|
||||||
"10.128.0.2": []any{"1.1.1.1:4242"},
|
"10.128.0.2": []interface{}{"1.1.1.1:4242"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
rc, err := yaml.Marshal(nc)
|
rc, err := yaml.Marshal(nc)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
c.ReloadConfigString(string(rc))
|
c.ReloadConfigString(string(rc))
|
||||||
|
|
||||||
err = lh.reload(c, false)
|
err = lh.reload(c, false)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {
|
func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {
|
||||||
@@ -417,7 +418,7 @@ func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tw *testEncWriter) GetCertState() *CertState {
|
func (tw *testEncWriter) GetCertState() *CertState {
|
||||||
return &CertState{initiatingVersion: tw.protocolVersion}
|
return &CertState{defaultVersion: tw.protocolVersion}
|
||||||
}
|
}
|
||||||
|
|
||||||
// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
|
// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
|
||||||
|
|||||||
6
main.go
6
main.go
@@ -13,10 +13,10 @@ import (
|
|||||||
"github.com/slackhq/nebula/sshd"
|
"github.com/slackhq/nebula/sshd"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
type m = map[string]any
|
type m map[string]interface{}
|
||||||
|
|
||||||
func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
|
func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
@@ -284,7 +284,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
var dnsStart func()
|
var dnsStart func()
|
||||||
if lightHouse.amLighthouse && serveDns {
|
if lightHouse.amLighthouse && serveDns {
|
||||||
l.Debugln("Starting dns server")
|
l.Debugln("Starting dns server")
|
||||||
dnsStart = dnsMain(ctx, l, pki.getCertState(), hostMap, c)
|
dnsStart = dnsMain(l, pki.getCertState(), hostMap, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Control{
|
return &Control{
|
||||||
|
|||||||
18
metadata.go
Normal file
18
metadata.go
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
package nebula
|
||||||
|
|
||||||
|
/*
|
||||||
|
|
||||||
|
import (
|
||||||
|
proto "google.golang.org/protobuf/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func HandleMetaProto(p []byte) {
|
||||||
|
m := &NebulaMeta{}
|
||||||
|
err := proto.Unmarshal(p, m)
|
||||||
|
if err != nil {
|
||||||
|
l.Debugf("problem unmarshaling meta message: %s", err)
|
||||||
|
}
|
||||||
|
//fmt.Println(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
*/
|
||||||
@@ -31,7 +31,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
|
|
||||||
//l.Error("in packet ", header, packet[HeaderLen:])
|
//l.Error("in packet ", header, packet[HeaderLen:])
|
||||||
if ip.IsValid() {
|
if ip.IsValid() {
|
||||||
if f.myVpnNetworksTable.Contains(ip.Addr()) {
|
_, found := f.myVpnNetworksTable.Lookup(ip.Addr())
|
||||||
|
if found {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
|
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
|
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -21,13 +20,13 @@ func Test_newPacket(t *testing.T) {
|
|||||||
|
|
||||||
// length fails
|
// length fails
|
||||||
err := newPacket([]byte{}, true, p)
|
err := newPacket([]byte{}, true, p)
|
||||||
require.ErrorIs(t, err, ErrPacketTooShort)
|
assert.ErrorIs(t, err, ErrPacketTooShort)
|
||||||
|
|
||||||
err = newPacket([]byte{0x40}, true, p)
|
err = newPacket([]byte{0x40}, true, p)
|
||||||
require.ErrorIs(t, err, ErrIPv4PacketTooShort)
|
assert.ErrorIs(t, err, ErrIPv4PacketTooShort)
|
||||||
|
|
||||||
err = newPacket([]byte{0x60}, true, p)
|
err = newPacket([]byte{0x60}, true, p)
|
||||||
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||||
|
|
||||||
// length fail with ip options
|
// length fail with ip options
|
||||||
h := ipv4.Header{
|
h := ipv4.Header{
|
||||||
@@ -40,15 +39,15 @@ func Test_newPacket(t *testing.T) {
|
|||||||
|
|
||||||
b, _ := h.Marshal()
|
b, _ := h.Marshal()
|
||||||
err = newPacket(b, true, p)
|
err = newPacket(b, true, p)
|
||||||
require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
|
assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
|
||||||
|
|
||||||
// not an ipv4 packet
|
// not an ipv4 packet
|
||||||
err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
||||||
require.ErrorIs(t, err, ErrUnknownIPVersion)
|
assert.ErrorIs(t, err, ErrUnknownIPVersion)
|
||||||
|
|
||||||
// invalid ihl
|
// invalid ihl
|
||||||
err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
||||||
require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
|
assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
|
||||||
|
|
||||||
// account for variable ip header length - incoming
|
// account for variable ip header length - incoming
|
||||||
h = ipv4.Header{
|
h = ipv4.Header{
|
||||||
@@ -64,7 +63,7 @@ func Test_newPacket(t *testing.T) {
|
|||||||
b = append(b, []byte{0, 3, 0, 4}...)
|
b = append(b, []byte{0, 3, 0, 4}...)
|
||||||
err = newPacket(b, true, p)
|
err = newPacket(b, true, p)
|
||||||
|
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr)
|
||||||
@@ -86,7 +85,7 @@ func Test_newPacket(t *testing.T) {
|
|||||||
b = append(b, []byte{0, 5, 0, 6}...)
|
b = append(b, []byte{0, 5, 0, 6}...)
|
||||||
err = newPacket(b, false, p)
|
err = newPacket(b, false, p)
|
||||||
|
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, uint8(2), p.Protocol)
|
assert.Equal(t, uint8(2), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr)
|
||||||
@@ -112,10 +111,10 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
FixLengths: false,
|
FixLengths: false,
|
||||||
}
|
}
|
||||||
err := gopacket.SerializeLayers(buffer, opt, &ip)
|
err := gopacket.SerializeLayers(buffer, opt, &ip)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = newPacket(buffer.Bytes(), true, p)
|
err = newPacket(buffer.Bytes(), true, p)
|
||||||
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||||
|
|
||||||
// A good ICMP packet
|
// A good ICMP packet
|
||||||
ip = layers.IPv6{
|
ip = layers.IPv6{
|
||||||
@@ -135,7 +134,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = newPacket(buffer.Bytes(), true, p)
|
err = newPacket(buffer.Bytes(), true, p)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
|
assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
@@ -147,7 +146,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
b := buffer.Bytes()
|
b := buffer.Bytes()
|
||||||
b[6] = byte(layers.IPProtocolESP)
|
b[6] = byte(layers.IPProtocolESP)
|
||||||
err = newPacket(b, true, p)
|
err = newPacket(b, true, p)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol)
|
assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
@@ -159,7 +158,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
b = buffer.Bytes()
|
b = buffer.Bytes()
|
||||||
b[6] = byte(layers.IPProtocolNoNextHeader)
|
b[6] = byte(layers.IPProtocolNoNextHeader)
|
||||||
err = newPacket(b, true, p)
|
err = newPacket(b, true, p)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol)
|
assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
@@ -171,7 +170,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
b = buffer.Bytes()
|
b = buffer.Bytes()
|
||||||
b[6] = 255 // 255 is a reserved protocol number
|
b[6] = 255 // 255 is a reserved protocol number
|
||||||
err = newPacket(b, true, p)
|
err = newPacket(b, true, p)
|
||||||
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||||
|
|
||||||
// A good UDP packet
|
// A good UDP packet
|
||||||
ip = layers.IPv6{
|
ip = layers.IPv6{
|
||||||
@@ -187,7 +186,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
DstPort: layers.UDPPort(22),
|
DstPort: layers.UDPPort(22),
|
||||||
}
|
}
|
||||||
err = udp.SetNetworkLayerForChecksum(&ip)
|
err = udp.SetNetworkLayerForChecksum(&ip)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
buffer.Clear()
|
buffer.Clear()
|
||||||
err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef}))
|
err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef}))
|
||||||
@@ -198,7 +197,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
|
|
||||||
// incoming
|
// incoming
|
||||||
err = newPacket(b, true, p)
|
err = newPacket(b, true, p)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
@@ -208,7 +207,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
|
|
||||||
// outgoing
|
// outgoing
|
||||||
err = newPacket(b, false, p)
|
err = newPacket(b, false, p)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
||||||
@@ -218,14 +217,14 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
|
|
||||||
// Too short UDP packet
|
// Too short UDP packet
|
||||||
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
|
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
|
||||||
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||||
|
|
||||||
// A good TCP packet
|
// A good TCP packet
|
||||||
b[6] = byte(layers.IPProtocolTCP)
|
b[6] = byte(layers.IPProtocolTCP)
|
||||||
|
|
||||||
// incoming
|
// incoming
|
||||||
err = newPacket(b, true, p)
|
err = newPacket(b, true, p)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
@@ -235,7 +234,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
|
|
||||||
// outgoing
|
// outgoing
|
||||||
err = newPacket(b, false, p)
|
err = newPacket(b, false, p)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
||||||
@@ -245,7 +244,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
|
|
||||||
// Too short TCP packet
|
// Too short TCP packet
|
||||||
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
|
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
|
||||||
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||||
|
|
||||||
// A good UDP packet with an AH header
|
// A good UDP packet with an AH header
|
||||||
ip = layers.IPv6{
|
ip = layers.IPv6{
|
||||||
@@ -280,7 +279,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
b = append(b, udpHeader...)
|
b = append(b, udpHeader...)
|
||||||
|
|
||||||
err = newPacket(b, true, p)
|
err = newPacket(b, true, p)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
@@ -291,7 +290,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
// Invalid AH header
|
// Invalid AH header
|
||||||
b = buffer.Bytes()
|
b = buffer.Bytes()
|
||||||
err = newPacket(b, true, p)
|
err = newPacket(b, true, p)
|
||||||
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_newPacket_ipv6Fragment(t *testing.T) {
|
func Test_newPacket_ipv6Fragment(t *testing.T) {
|
||||||
@@ -339,7 +338,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
|
|||||||
|
|
||||||
// Test first fragment incoming
|
// Test first fragment incoming
|
||||||
err = newPacket(firstFrag, true, p)
|
err = newPacket(firstFrag, true, p)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
|
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
|
||||||
@@ -349,7 +348,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
|
|||||||
|
|
||||||
// Test first fragment outgoing
|
// Test first fragment outgoing
|
||||||
err = newPacket(firstFrag, false, p)
|
err = newPacket(firstFrag, false, p)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
||||||
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
|
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
|
||||||
@@ -378,7 +377,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
|
|||||||
|
|
||||||
// Test second fragment incoming
|
// Test second fragment incoming
|
||||||
err = newPacket(secondFrag, true, p)
|
err = newPacket(secondFrag, true, p)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
|
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
|
||||||
@@ -388,7 +387,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
|
|||||||
|
|
||||||
// Test second fragment outgoing
|
// Test second fragment outgoing
|
||||||
err = newPacket(secondFrag, false, p)
|
err = newPacket(secondFrag, false, p)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
||||||
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
|
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
|
||||||
@@ -398,7 +397,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
|
|||||||
|
|
||||||
// Too short of a fragment packet
|
// Too short of a fragment packet
|
||||||
err = newPacket(secondFrag[:len(secondFrag)-10], false, p)
|
err = newPacket(secondFrag[:len(secondFrag)-10], false, p)
|
||||||
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkParseV6(b *testing.B) {
|
func BenchmarkParseV6(b *testing.B) {
|
||||||
|
|||||||
@@ -3,8 +3,6 @@ package overlay
|
|||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/routing"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Device interface {
|
type Device interface {
|
||||||
@@ -12,6 +10,6 @@ type Device interface {
|
|||||||
Activate() error
|
Activate() error
|
||||||
Networks() []netip.Prefix
|
Networks() []netip.Prefix
|
||||||
Name() string
|
Name() string
|
||||||
RoutesFor(netip.Addr) routing.Gateways
|
RouteFor(netip.Addr) netip.Addr
|
||||||
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,14 +11,13 @@ import (
|
|||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Route struct {
|
type Route struct {
|
||||||
MTU int
|
MTU int
|
||||||
Metric int
|
Metric int
|
||||||
Cidr netip.Prefix
|
Cidr netip.Prefix
|
||||||
Via routing.Gateways
|
Via netip.Addr
|
||||||
Install bool
|
Install bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -48,17 +47,15 @@ func (r Route) String() string {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) {
|
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[netip.Addr], error) {
|
||||||
routeTree := new(bart.Table[routing.Gateways])
|
routeTree := new(bart.Table[netip.Addr])
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !allowMTU && r.MTU > 0 {
|
if !allowMTU && r.MTU > 0 {
|
||||||
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
|
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
|
||||||
}
|
}
|
||||||
|
|
||||||
gateways := r.Via
|
if r.Via.IsValid() {
|
||||||
if len(gateways) > 0 {
|
routeTree.Insert(r.Cidr, r.Via)
|
||||||
routing.CalculateBucketsForGateways(gateways)
|
|
||||||
routeTree.Insert(r.Cidr, gateways)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return routeTree, nil
|
return routeTree, nil
|
||||||
@@ -72,7 +69,7 @@ func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
|
|||||||
return []Route{}, nil
|
return []Route{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rawRoutes, ok := r.([]any)
|
rawRoutes, ok := r.([]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("tun.routes is not an array")
|
return nil, fmt.Errorf("tun.routes is not an array")
|
||||||
}
|
}
|
||||||
@@ -83,7 +80,7 @@ func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
|
|||||||
|
|
||||||
routes := make([]Route, len(rawRoutes))
|
routes := make([]Route, len(rawRoutes))
|
||||||
for i, r := range rawRoutes {
|
for i, r := range rawRoutes {
|
||||||
m, ok := r.(map[string]any)
|
m, ok := r.(map[interface{}]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("entry %v in tun.routes is invalid", i+1)
|
return nil, fmt.Errorf("entry %v in tun.routes is invalid", i+1)
|
||||||
}
|
}
|
||||||
@@ -151,7 +148,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
|
|||||||
return []Route{}, nil
|
return []Route{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rawRoutes, ok := r.([]any)
|
rawRoutes, ok := r.([]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("tun.unsafe_routes is not an array")
|
return nil, fmt.Errorf("tun.unsafe_routes is not an array")
|
||||||
}
|
}
|
||||||
@@ -162,7 +159,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
|
|||||||
|
|
||||||
routes := make([]Route, len(rawRoutes))
|
routes := make([]Route, len(rawRoutes))
|
||||||
for i, r := range rawRoutes {
|
for i, r := range rawRoutes {
|
||||||
m, ok := r.(map[string]any)
|
m, ok := r.(map[interface{}]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("entry %v in tun.unsafe_routes is invalid", i+1)
|
return nil, fmt.Errorf("entry %v in tun.unsafe_routes is invalid", i+1)
|
||||||
}
|
}
|
||||||
@@ -204,63 +201,14 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
|
|||||||
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not present", i+1)
|
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not present", i+1)
|
||||||
}
|
}
|
||||||
|
|
||||||
var gateways routing.Gateways
|
via, ok := rVia.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia)
|
||||||
|
}
|
||||||
|
|
||||||
switch via := rVia.(type) {
|
viaVpnIp, err := netip.ParseAddr(via)
|
||||||
case string:
|
if err != nil {
|
||||||
viaIp, err := netip.ParseAddr(via)
|
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
gateways = routing.Gateways{routing.NewGateway(viaIp, 1)}
|
|
||||||
|
|
||||||
case []any:
|
|
||||||
gateways = make(routing.Gateways, len(via))
|
|
||||||
for ig, v := range via {
|
|
||||||
gatewayMap, ok := v.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("entry %v in tun.unsafe_routes[%v].via is invalid", i+1, ig+1)
|
|
||||||
}
|
|
||||||
|
|
||||||
rGateway, ok := gatewayMap["gateway"]
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not present", i+1, ig+1)
|
|
||||||
}
|
|
||||||
|
|
||||||
parsedGateway, ok := rGateway.(string)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not a string", i+1, ig+1)
|
|
||||||
}
|
|
||||||
|
|
||||||
gatewayIp, err := netip.ParseAddr(parsedGateway)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] failed to parse address: %v", i+1, ig+1, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rGatewayWeight, ok := gatewayMap["weight"]
|
|
||||||
if !ok {
|
|
||||||
rGatewayWeight = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
gatewayWeight, ok := rGatewayWeight.(int)
|
|
||||||
if !ok {
|
|
||||||
_, err = strconv.ParseInt(rGatewayWeight.(string), 10, 32)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not an integer", i+1, ig+1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if gatewayWeight < 1 || gatewayWeight > math.MaxInt32 {
|
|
||||||
return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not in range (1-%d) : %v", i+1, ig+1, math.MaxInt32, gatewayWeight)
|
|
||||||
}
|
|
||||||
|
|
||||||
gateways[ig] = routing.NewGateway(gatewayIp, gatewayWeight)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string or list of gateways: found %T", i+1, rVia)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rRoute, ok := m["route"]
|
rRoute, ok := m["route"]
|
||||||
@@ -278,7 +226,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r := Route{
|
r := Route{
|
||||||
Via: gateways,
|
Via: viaVpnIp,
|
||||||
MTU: mtu,
|
MTU: mtu,
|
||||||
Metric: metric,
|
Metric: metric,
|
||||||
Install: install,
|
Install: install,
|
||||||
|
|||||||
@@ -6,96 +6,94 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
|
||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_parseRoutes(t *testing.T) {
|
func Test_parseRoutes(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
n, err := netip.ParsePrefix("10.0.0.0/24")
|
n, err := netip.ParsePrefix("10.0.0.0/24")
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// test no routes config
|
// test no routes config
|
||||||
routes, err := parseRoutes(c, []netip.Prefix{n})
|
routes, err := parseRoutes(c, []netip.Prefix{n})
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Empty(t, routes)
|
assert.Len(t, routes, 0)
|
||||||
|
|
||||||
// not an array
|
// not an array
|
||||||
c.Settings["tun"] = map[string]any{"routes": "hi"}
|
c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"}
|
||||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
require.EqualError(t, err, "tun.routes is not an array")
|
assert.EqualError(t, err, "tun.routes is not an array")
|
||||||
|
|
||||||
// no routes
|
// no routes
|
||||||
c.Settings["tun"] = map[string]any{"routes": []any{}}
|
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}}
|
||||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Empty(t, routes)
|
assert.Len(t, routes, 0)
|
||||||
|
|
||||||
// weird route
|
// weird route
|
||||||
c.Settings["tun"] = map[string]any{"routes": []any{"asdf"}}
|
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}}
|
||||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
require.EqualError(t, err, "entry 1 in tun.routes is invalid")
|
assert.EqualError(t, err, "entry 1 in tun.routes is invalid")
|
||||||
|
|
||||||
// no mtu
|
// no mtu
|
||||||
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}}
|
||||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
require.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
|
assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
|
||||||
|
|
||||||
// bad mtu
|
// bad mtu
|
||||||
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "nope"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}}
|
||||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
require.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
|
assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
|
||||||
|
|
||||||
// low mtu
|
// low mtu
|
||||||
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "499"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}}
|
||||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
require.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
|
assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
|
||||||
|
|
||||||
// missing route
|
// missing route
|
||||||
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}}
|
||||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
require.EqualError(t, err, "entry 1.route in tun.routes is not present")
|
assert.EqualError(t, err, "entry 1.route in tun.routes is not present")
|
||||||
|
|
||||||
// unparsable route
|
// unparsable route
|
||||||
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "nope"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
|
||||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
require.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
|
assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
|
||||||
|
|
||||||
// below network range
|
// below network range
|
||||||
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "1.0.0.0/8"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
|
||||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]")
|
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]")
|
||||||
|
|
||||||
// above network range
|
// above network range
|
||||||
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "10.0.1.0/24"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}}
|
||||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]")
|
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]")
|
||||||
|
|
||||||
// Not in multiple ranges
|
// Not in multiple ranges
|
||||||
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "192.0.0.0/24"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}}
|
||||||
routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")})
|
routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]")
|
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]")
|
||||||
|
|
||||||
// happy case
|
// happy case
|
||||||
c.Settings["tun"] = map[string]any{"routes": []any{
|
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{
|
||||||
map[string]any{"mtu": "9000", "route": "10.0.0.0/29"},
|
map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"},
|
||||||
map[string]any{"mtu": "8000", "route": "10.0.0.1/32"},
|
map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"},
|
||||||
}}
|
}}
|
||||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Len(t, routes, 2)
|
assert.Len(t, routes, 2)
|
||||||
|
|
||||||
tested := 0
|
tested := 0
|
||||||
@@ -121,140 +119,116 @@ func Test_parseUnsafeRoutes(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
n, err := netip.ParsePrefix("10.0.0.0/24")
|
n, err := netip.ParsePrefix("10.0.0.0/24")
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// test no routes config
|
// test no routes config
|
||||||
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Empty(t, routes)
|
assert.Len(t, routes, 0)
|
||||||
|
|
||||||
// not an array
|
// not an array
|
||||||
c.Settings["tun"] = map[string]any{"unsafe_routes": "hi"}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
require.EqualError(t, err, "tun.unsafe_routes is not an array")
|
assert.EqualError(t, err, "tun.unsafe_routes is not an array")
|
||||||
|
|
||||||
// no routes
|
// no routes
|
||||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Empty(t, routes)
|
assert.Len(t, routes, 0)
|
||||||
|
|
||||||
// weird route
|
// weird route
|
||||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{"asdf"}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
require.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
|
assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
|
||||||
|
|
||||||
// no via
|
// no via
|
||||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
|
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
|
||||||
|
|
||||||
// invalid via
|
// invalid via
|
||||||
for _, invalidValue := range []any{
|
for _, invalidValue := range []interface{}{
|
||||||
127, false, nil, 1.0, []string{"1", "2"},
|
127, false, nil, 1.0, []string{"1", "2"},
|
||||||
} {
|
} {
|
||||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": invalidValue}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string or list of gateways: found %T", invalidValue))
|
assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unparsable list of via
|
|
||||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": []string{"1", "2"}}}}
|
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
|
||||||
assert.Nil(t, routes)
|
|
||||||
require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not a string or list of gateways: found []string")
|
|
||||||
|
|
||||||
// unparsable via
|
// unparsable via
|
||||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": "nope"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
require.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
|
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
|
||||||
|
|
||||||
// unparsable gateway
|
|
||||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"gateway": "1"}}}}}
|
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
|
||||||
assert.Nil(t, routes)
|
|
||||||
require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] failed to parse address: ParseAddr(\"1\"): unable to parse IP")
|
|
||||||
|
|
||||||
// missing gateway element
|
|
||||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"weight": "1"}}}}}
|
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
|
||||||
assert.Nil(t, routes)
|
|
||||||
require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] is not present")
|
|
||||||
|
|
||||||
// unparsable weight element
|
|
||||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"gateway": "10.0.0.1", "weight": "a"}}}}}
|
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
|
||||||
assert.Nil(t, routes)
|
|
||||||
require.EqualError(t, err, "entry .weight in tun.unsafe_routes[1].via[1] is not an integer")
|
|
||||||
|
|
||||||
// missing route
|
// missing route
|
||||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "500"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
|
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
|
||||||
|
|
||||||
// unparsable route
|
// unparsable route
|
||||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
require.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
|
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
|
||||||
|
|
||||||
// within network range
|
// within network range
|
||||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24")
|
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24")
|
||||||
|
|
||||||
// below network range
|
// below network range
|
||||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Len(t, routes, 1)
|
assert.Len(t, routes, 1)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
// above network range
|
// above network range
|
||||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Len(t, routes, 1)
|
assert.Len(t, routes, 1)
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
// no mtu
|
// no mtu
|
||||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Len(t, routes, 1)
|
assert.Len(t, routes, 1)
|
||||||
assert.Equal(t, 0, routes[0].MTU)
|
assert.Equal(t, 0, routes[0].MTU)
|
||||||
|
|
||||||
// bad mtu
|
// bad mtu
|
||||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "nope"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
|
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
|
||||||
|
|
||||||
// low mtu
|
// low mtu
|
||||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "499"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
|
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
|
||||||
|
|
||||||
// bad install
|
// bad install
|
||||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
require.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
|
assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
|
||||||
|
|
||||||
// happy case
|
// happy case
|
||||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
|
||||||
map[string]any{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "t"},
|
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "t"},
|
||||||
map[string]any{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32", "install": 0},
|
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32", "install": 0},
|
||||||
map[string]any{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1},
|
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1},
|
||||||
map[string]any{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
|
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
|
||||||
}}
|
}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
require.NoError(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Len(t, routes, 4)
|
assert.Len(t, routes, 4)
|
||||||
|
|
||||||
tested := 0
|
tested := 0
|
||||||
@@ -286,119 +260,38 @@ func Test_makeRouteTree(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
n, err := netip.ParsePrefix("10.0.0.0/24")
|
n, err := netip.ParsePrefix("10.0.0.0/24")
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
|
||||||
map[string]any{"via": "192.168.0.1", "route": "1.0.0.0/28"},
|
map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
|
||||||
map[string]any{"via": "192.168.0.2", "route": "1.0.0.1/32"},
|
map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"},
|
||||||
}}
|
}}
|
||||||
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Len(t, routes, 2)
|
assert.Len(t, routes, 2)
|
||||||
routeTree, err := makeRouteTree(l, routes, true)
|
routeTree, err := makeRouteTree(l, routes, true)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
ip, err := netip.ParseAddr("1.0.0.2")
|
ip, err := netip.ParseAddr("1.0.0.2")
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
r, ok := routeTree.Lookup(ip)
|
r, ok := routeTree.Lookup(ip)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
nip, err := netip.ParseAddr("192.168.0.1")
|
nip, err := netip.ParseAddr("192.168.0.1")
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, nip, r[0].Addr())
|
assert.Equal(t, nip, r)
|
||||||
|
|
||||||
ip, err = netip.ParseAddr("1.0.0.1")
|
ip, err = netip.ParseAddr("1.0.0.1")
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
r, ok = routeTree.Lookup(ip)
|
r, ok = routeTree.Lookup(ip)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
nip, err = netip.ParseAddr("192.168.0.2")
|
nip, err = netip.ParseAddr("192.168.0.2")
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, nip, r[0].Addr())
|
assert.Equal(t, nip, r)
|
||||||
|
|
||||||
ip, err = netip.ParseAddr("1.1.0.1")
|
ip, err = netip.ParseAddr("1.1.0.1")
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
r, ok = routeTree.Lookup(ip)
|
r, ok = routeTree.Lookup(ip)
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_makeMultipathUnsafeRouteTree(t *testing.T) {
|
|
||||||
l := test.NewLogger()
|
|
||||||
c := config.NewC(l)
|
|
||||||
n, err := netip.ParsePrefix("10.0.0.0/24")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
c.Settings["tun"] = map[string]any{
|
|
||||||
"unsafe_routes": []any{
|
|
||||||
map[string]any{
|
|
||||||
"route": "192.168.86.0/24",
|
|
||||||
"via": "192.168.100.10",
|
|
||||||
},
|
|
||||||
map[string]any{
|
|
||||||
"route": "192.168.87.0/24",
|
|
||||||
"via": []any{
|
|
||||||
map[string]any{
|
|
||||||
"gateway": "10.0.0.1",
|
|
||||||
},
|
|
||||||
map[string]any{
|
|
||||||
"gateway": "10.0.0.2",
|
|
||||||
},
|
|
||||||
map[string]any{
|
|
||||||
"gateway": "10.0.0.3",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
map[string]any{
|
|
||||||
"route": "192.168.89.0/24",
|
|
||||||
"via": []any{
|
|
||||||
map[string]any{
|
|
||||||
"gateway": "10.0.0.1",
|
|
||||||
"weight": 10,
|
|
||||||
},
|
|
||||||
map[string]any{
|
|
||||||
"gateway": "10.0.0.2",
|
|
||||||
"weight": 5,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Len(t, routes, 3)
|
|
||||||
routeTree, err := makeRouteTree(l, routes, true)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
ip, err := netip.ParseAddr("192.168.86.1")
|
|
||||||
require.NoError(t, err)
|
|
||||||
r, ok := routeTree.Lookup(ip)
|
|
||||||
assert.True(t, ok)
|
|
||||||
|
|
||||||
nip, err := netip.ParseAddr("192.168.100.10")
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, nip, r[0].Addr())
|
|
||||||
|
|
||||||
ip, err = netip.ParseAddr("192.168.87.1")
|
|
||||||
require.NoError(t, err)
|
|
||||||
r, ok = routeTree.Lookup(ip)
|
|
||||||
assert.True(t, ok)
|
|
||||||
|
|
||||||
expectedGateways := routing.Gateways{routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 1),
|
|
||||||
routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 1),
|
|
||||||
routing.NewGateway(netip.MustParseAddr("10.0.0.3"), 1)}
|
|
||||||
|
|
||||||
routing.CalculateBucketsForGateways(expectedGateways)
|
|
||||||
assert.ElementsMatch(t, expectedGateways, r)
|
|
||||||
|
|
||||||
ip, err = netip.ParseAddr("192.168.89.1")
|
|
||||||
require.NoError(t, err)
|
|
||||||
r, ok = routeTree.Lookup(ip)
|
|
||||||
assert.True(t, ok)
|
|
||||||
|
|
||||||
expectedGateways = routing.Gateways{routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 10),
|
|
||||||
routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 5)}
|
|
||||||
|
|
||||||
routing.CalculateBucketsForGateways(expectedGateways)
|
|
||||||
assert.ElementsMatch(t, expectedGateways, r)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -22,7 +21,7 @@ type tun struct {
|
|||||||
fd int
|
fd int
|
||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -57,7 +56,7 @@ func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, erro
|
|||||||
return nil, fmt.Errorf("newTun not supported in Android")
|
return nil, fmt.Errorf("newTun not supported in Android")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||||
r, _ := t.routeTree.Load().Lookup(ip)
|
r, _ := t.routeTree.Load().Lookup(ip)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ import (
|
|||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
netroute "golang.org/x/net/route"
|
netroute "golang.org/x/net/route"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
@@ -29,7 +28,7 @@ type tun struct {
|
|||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
DefaultMTU int
|
DefaultMTU int
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||||
linkAddr *netroute.LinkAddr
|
linkAddr *netroute.LinkAddr
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
|
||||||
@@ -295,7 +294,6 @@ func (t *tun) activate6(network netip.Prefix) error {
|
|||||||
Vltime: 0xffffffff,
|
Vltime: 0xffffffff,
|
||||||
Pltime: 0xffffffff,
|
Pltime: 0xffffffff,
|
||||||
},
|
},
|
||||||
//TODO: CERT-V2 should we disable DAD (duplicate address detection) and mark this as a secured address?
|
|
||||||
Flags: _IN6_IFF_NODAD,
|
Flags: _IN6_IFF_NODAD,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -343,12 +341,12 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||||
r, ok := t.routeTree.Load().Lookup(ip)
|
r, ok := t.routeTree.Load().Lookup(ip)
|
||||||
if ok {
|
if ok {
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
return routing.Gateways{}
|
return netip.Addr{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the LinkAddr for the interface of the given name
|
// Get the LinkAddr for the interface of the given name
|
||||||
@@ -383,7 +381,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
routes := *t.Routes.Load()
|
routes := *t.Routes.Load()
|
||||||
|
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if len(r.Via) == 0 || !r.Install {
|
if !r.Via.IsValid() || !r.Install {
|
||||||
// We don't allow route MTUs so only install routes with a via
|
// We don't allow route MTUs so only install routes with a via
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -394,7 +392,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
t.l.WithField("route", r.Cidr).
|
t.l.WithField("route", r.Cidr).
|
||||||
Warnf("unable to add unsafe_route, identical route already exists")
|
Warnf("unable to add unsafe_route, identical route already exists")
|
||||||
} else {
|
} else {
|
||||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
|
||||||
if logErrors {
|
if logErrors {
|
||||||
retErr.Log(t.l)
|
retErr.Log(t.l)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/iputil"
|
"github.com/slackhq/nebula/iputil"
|
||||||
"github.com/slackhq/nebula/routing"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type disabledTun struct {
|
type disabledTun struct {
|
||||||
@@ -44,8 +43,8 @@ func (*disabledTun) Activate() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*disabledTun) RoutesFor(addr netip.Addr) routing.Gateways {
|
func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr {
|
||||||
return routing.Gateways{}
|
return netip.Addr{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *disabledTun) Networks() []netip.Prefix {
|
func (t *disabledTun) Networks() []netip.Prefix {
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ import (
|
|||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -51,7 +50,7 @@ type tun struct {
|
|||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
MTU int
|
MTU int
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
|
||||||
io.ReadWriteCloser
|
io.ReadWriteCloser
|
||||||
@@ -243,7 +242,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||||
r, _ := t.routeTree.Load().Lookup(ip)
|
r, _ := t.routeTree.Load().Lookup(ip)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
@@ -263,7 +262,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
|||||||
func (t *tun) addRoutes(logErrors bool) error {
|
func (t *tun) addRoutes(logErrors bool) error {
|
||||||
routes := *t.Routes.Load()
|
routes := *t.Routes.Load()
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if len(r.Via) == 0 || !r.Install {
|
if !r.Via.IsValid() || !r.Install {
|
||||||
// We don't allow route MTUs so only install routes with a via
|
// We don't allow route MTUs so only install routes with a via
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -271,7 +270,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
|
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
|
||||||
t.l.Debug("command: ", cmd.String())
|
t.l.Debug("command: ", cmd.String())
|
||||||
if err := cmd.Run(); err != nil {
|
if err := cmd.Run(); err != nil {
|
||||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
|
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
|
||||||
if logErrors {
|
if logErrors {
|
||||||
retErr.Log(t.l)
|
retErr.Log(t.l)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ import (
|
|||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -24,7 +23,7 @@ type tun struct {
|
|||||||
io.ReadWriteCloser
|
io.ReadWriteCloser
|
||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,7 +79,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||||
r, _ := t.routeTree.Load().Lookup(ip)
|
r, _ := t.routeTree.Load().Lookup(ip)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ import (
|
|||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
@@ -35,7 +34,7 @@ type tun struct {
|
|||||||
ioctlFd uintptr
|
ioctlFd uintptr
|
||||||
|
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||||
routeChan chan struct{}
|
routeChan chan struct{}
|
||||||
useSystemRoutes bool
|
useSystemRoutes bool
|
||||||
|
|
||||||
@@ -232,7 +231,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
|||||||
return file, nil
|
return file, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||||
r, _ := t.routeTree.Load().Lookup(ip)
|
r, _ := t.routeTree.Load().Lookup(ip)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
@@ -464,7 +463,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
|
|
||||||
err := netlink.RouteReplace(&nr)
|
err := netlink.RouteReplace(&nr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
|
||||||
if logErrors {
|
if logErrors {
|
||||||
retErr.Log(t.l)
|
retErr.Log(t.l)
|
||||||
} else {
|
} else {
|
||||||
@@ -551,7 +550,20 @@ func (t *tun) watchRoutes() {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
|
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
||||||
|
if r.Gw == nil {
|
||||||
|
// Not a gateway route, ignore
|
||||||
|
t.l.WithField("route", r).Debug("Ignoring route update, not a gateway route")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
gwAddr, ok := netip.AddrFromSlice(r.Gw)
|
||||||
|
if !ok {
|
||||||
|
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
gwAddr = gwAddr.Unmap()
|
||||||
withinNetworks := false
|
withinNetworks := false
|
||||||
for i := range t.vpnNetworks {
|
for i := range t.vpnNetworks {
|
||||||
if t.vpnNetworks[i].Contains(gwAddr) {
|
if t.vpnNetworks[i].Contains(gwAddr) {
|
||||||
@@ -559,68 +571,9 @@ func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if !withinNetworks {
|
||||||
return withinNetworks
|
// Gateway isn't in our overlay network, ignore
|
||||||
}
|
t.l.WithField("route", r).Debug("Ignoring route update, not in our networks")
|
||||||
|
|
||||||
func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
|
|
||||||
|
|
||||||
var gateways routing.Gateways
|
|
||||||
|
|
||||||
link, err := netlink.LinkByName(t.Device)
|
|
||||||
if err != nil {
|
|
||||||
t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name")
|
|
||||||
return gateways
|
|
||||||
}
|
|
||||||
|
|
||||||
// If this route is relevant to our interface and there is a gateway then add it
|
|
||||||
if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 {
|
|
||||||
gwAddr, ok := netip.AddrFromSlice(r.Gw)
|
|
||||||
if !ok {
|
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
|
|
||||||
} else {
|
|
||||||
gwAddr = gwAddr.Unmap()
|
|
||||||
|
|
||||||
if !t.isGatewayInVpnNetworks(gwAddr) {
|
|
||||||
// Gateway isn't in our overlay network, ignore
|
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
|
|
||||||
} else {
|
|
||||||
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, p := range r.MultiPath {
|
|
||||||
// If this route is relevant to our interface and there is a gateway then add it
|
|
||||||
if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 {
|
|
||||||
gwAddr, ok := netip.AddrFromSlice(p.Gw)
|
|
||||||
if !ok {
|
|
||||||
t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address")
|
|
||||||
} else {
|
|
||||||
gwAddr = gwAddr.Unmap()
|
|
||||||
|
|
||||||
if !t.isGatewayInVpnNetworks(gwAddr) {
|
|
||||||
// Gateway isn't in our overlay network, ignore
|
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
|
|
||||||
} else {
|
|
||||||
// p.Hops+1 = weight of the route
|
|
||||||
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
routing.CalculateBucketsForGateways(gateways)
|
|
||||||
return gateways
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
|
||||||
|
|
||||||
gateways := t.getGatewaysFromRoute(&r.Route)
|
|
||||||
|
|
||||||
if len(gateways) == 0 {
|
|
||||||
// No gateways relevant to our network, no routing changes required.
|
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -636,12 +589,12 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
|||||||
newTree := t.routeTree.Load().Clone()
|
newTree := t.routeTree.Load().Clone()
|
||||||
|
|
||||||
if r.Type == unix.RTM_NEWROUTE {
|
if r.Type == unix.RTM_NEWROUTE {
|
||||||
t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route")
|
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route")
|
||||||
newTree.Insert(dst, gateways)
|
newTree.Insert(dst, gwAddr)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route")
|
|
||||||
newTree.Delete(dst)
|
newTree.Delete(dst)
|
||||||
|
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
|
||||||
}
|
}
|
||||||
t.routeTree.Store(newTree)
|
t.routeTree.Store(newTree)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ import (
|
|||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -32,7 +31,7 @@ type tun struct {
|
|||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
MTU int
|
MTU int
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
|
||||||
io.ReadWriteCloser
|
io.ReadWriteCloser
|
||||||
@@ -109,26 +108,34 @@ func (t *tun) addIp(cidr netip.Prefix) error {
|
|||||||
var err error
|
var err error
|
||||||
|
|
||||||
// TODO use syscalls instead of exec.Command
|
// TODO use syscalls instead of exec.Command
|
||||||
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
|
if cidr.Addr().Is6() {
|
||||||
t.l.Debug("command: ", cmd.String())
|
cmd := exec.Command("/sbin/ifconfig", t.Device, "inet6", cidr.Addr().String(), "prefixlen", strconv.Itoa(cidr.Bits()), "alias")
|
||||||
if err = cmd.Run(); err != nil {
|
t.l.Debug("command: ", cmd.String())
|
||||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
if err = cmd.Run(); err != nil {
|
||||||
|
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
|
||||||
|
t.l.Debug("command: ", cmd.String())
|
||||||
|
if err = cmd.Run(); err != nil {
|
||||||
|
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
|
||||||
|
t.l.Debug("command: ", cmd.String())
|
||||||
|
if err = cmd.Run(); err != nil {
|
||||||
|
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
|
||||||
|
t.l.Debug("command: ", cmd.String())
|
||||||
|
if err = cmd.Run(); err != nil {
|
||||||
|
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
|
return nil
|
||||||
t.l.Debug("command: ", cmd.String())
|
|
||||||
if err = cmd.Run(); err != nil {
|
|
||||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
|
|
||||||
t.l.Debug("command: ", cmd.String())
|
|
||||||
if err = cmd.Run(); err != nil {
|
|
||||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unsafe path routes
|
|
||||||
return t.addRoutes(false)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Activate() error {
|
func (t *tun) Activate() error {
|
||||||
@@ -138,7 +145,15 @@ func (t *tun) Activate() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
|
cmd := exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
|
||||||
|
t.l.Debug("command: ", cmd.String())
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
return fmt.Errorf("failed to run '%s': %s", cmd, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsafe path routes
|
||||||
|
return t.addRoutes(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) reload(c *config.C, initial bool) error {
|
func (t *tun) reload(c *config.C, initial bool) error {
|
||||||
@@ -178,7 +193,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||||
r, _ := t.routeTree.Load().Lookup(ip)
|
r, _ := t.routeTree.Load().Lookup(ip)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
@@ -198,7 +213,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
|||||||
func (t *tun) addRoutes(logErrors bool) error {
|
func (t *tun) addRoutes(logErrors bool) error {
|
||||||
routes := *t.Routes.Load()
|
routes := *t.Routes.Load()
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if len(r.Via) == 0 || !r.Install {
|
if !r.Via.IsValid() || !r.Install {
|
||||||
// We don't allow route MTUs so only install routes with a via
|
// We don't allow route MTUs so only install routes with a via
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -206,7 +221,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||||
t.l.Debug("command: ", cmd.String())
|
t.l.Debug("command: ", cmd.String())
|
||||||
if err := cmd.Run(); err != nil {
|
if err := cmd.Run(); err != nil {
|
||||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
|
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
|
||||||
if logErrors {
|
if logErrors {
|
||||||
retErr.Log(t.l)
|
retErr.Log(t.l)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ import (
|
|||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -26,7 +25,7 @@ type tun struct {
|
|||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
MTU int
|
MTU int
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
|
||||||
io.ReadWriteCloser
|
io.ReadWriteCloser
|
||||||
@@ -159,7 +158,7 @@ func (t *tun) Activate() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||||
r, _ := t.routeTree.Load().Lookup(ip)
|
r, _ := t.routeTree.Load().Lookup(ip)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
@@ -167,7 +166,7 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
|||||||
func (t *tun) addRoutes(logErrors bool) error {
|
func (t *tun) addRoutes(logErrors bool) error {
|
||||||
routes := *t.Routes.Load()
|
routes := *t.Routes.Load()
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if len(r.Via) == 0 || !r.Install {
|
if !r.Via.IsValid() || !r.Install {
|
||||||
// We don't allow route MTUs so only install routes with a via
|
// We don't allow route MTUs so only install routes with a via
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -175,7 +174,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||||
t.l.Debug("command: ", cmd.String())
|
t.l.Debug("command: ", cmd.String())
|
||||||
if err := cmd.Run(); err != nil {
|
if err := cmd.Run(); err != nil {
|
||||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
|
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
|
||||||
if logErrors {
|
if logErrors {
|
||||||
retErr.Log(t.l)
|
retErr.Log(t.l)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -13,14 +13,13 @@ import (
|
|||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type TestTun struct {
|
type TestTun struct {
|
||||||
Device string
|
Device string
|
||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
Routes []Route
|
Routes []Route
|
||||||
routeTree *bart.Table[routing.Gateways]
|
routeTree *bart.Table[netip.Addr]
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
|
||||||
closed atomic.Bool
|
closed atomic.Bool
|
||||||
@@ -87,7 +86,7 @@ func (t *TestTun) Get(block bool) []byte {
|
|||||||
// Below this is boilerplate implementation to make nebula actually work
|
// Below this is boilerplate implementation to make nebula actually work
|
||||||
//********************************************************************************************************************//
|
//********************************************************************************************************************//
|
||||||
|
|
||||||
func (t *TestTun) RoutesFor(ip netip.Addr) routing.Gateways {
|
func (t *TestTun) RouteFor(ip netip.Addr) netip.Addr {
|
||||||
r, _ := t.routeTree.Lookup(ip)
|
r, _ := t.routeTree.Lookup(ip)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ import (
|
|||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
"github.com/slackhq/nebula/wintun"
|
"github.com/slackhq/nebula/wintun"
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
@@ -32,7 +31,7 @@ type winTun struct {
|
|||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
MTU int
|
MTU int
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
|
||||||
tun *wintun.NativeTun
|
tun *wintun.NativeTun
|
||||||
@@ -148,18 +147,15 @@ func (t *winTun) addRoutes(logErrors bool) error {
|
|||||||
foundDefault4 := false
|
foundDefault4 := false
|
||||||
|
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if len(r.Via) == 0 || !r.Install {
|
if !r.Via.IsValid() || !r.Install {
|
||||||
// We don't allow route MTUs so only install routes with a via
|
// We don't allow route MTUs so only install routes with a via
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add our unsafe route
|
// Add our unsafe route
|
||||||
// Windows does not support multipath routes natively, so we install only a single route.
|
err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
|
||||||
// This is not a problem as traffic will always be sent to Nebula which handles the multipath routing internally.
|
|
||||||
// In effect this provides multipath routing support to windows supporting loadbalancing and redundancy.
|
|
||||||
err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
|
||||||
if logErrors {
|
if logErrors {
|
||||||
retErr.Log(t.l)
|
retErr.Log(t.l)
|
||||||
continue
|
continue
|
||||||
@@ -202,8 +198,7 @@ func (t *winTun) removeRoutes(routes []Route) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// See comment on luid.AddRoute
|
err := luid.DeleteRoute(r.Cidr, r.Via)
|
||||||
err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||||
} else {
|
} else {
|
||||||
@@ -213,7 +208,7 @@ func (t *winTun) removeRoutes(routes []Route) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *winTun) RoutesFor(ip netip.Addr) routing.Gateways {
|
func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
|
||||||
r, _ := t.routeTree.Load().Lookup(ip)
|
r, _ := t.routeTree.Load().Lookup(ip)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||||
@@ -39,13 +38,9 @@ type UserDevice struct {
|
|||||||
func (d *UserDevice) Activate() error {
|
func (d *UserDevice) Activate() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks }
|
||||||
func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks }
|
func (d *UserDevice) Name() string { return "faketun0" }
|
||||||
func (d *UserDevice) Name() string { return "faketun0" }
|
func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip }
|
||||||
func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
|
|
||||||
return routing.Gateways{routing.NewGateway(ip, 1)}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
|||||||
59
pki.go
59
pki.go
@@ -33,16 +33,16 @@ type CertState struct {
|
|||||||
v2Cert cert.Certificate
|
v2Cert cert.Certificate
|
||||||
v2HandshakeBytes []byte
|
v2HandshakeBytes []byte
|
||||||
|
|
||||||
initiatingVersion cert.Version
|
defaultVersion cert.Version
|
||||||
privateKey []byte
|
privateKey []byte
|
||||||
pkcs11Backed bool
|
pkcs11Backed bool
|
||||||
cipher string
|
cipher string
|
||||||
|
|
||||||
myVpnNetworks []netip.Prefix
|
myVpnNetworks []netip.Prefix
|
||||||
myVpnNetworksTable *bart.Lite
|
myVpnNetworksTable *bart.Table[struct{}]
|
||||||
myVpnAddrs []netip.Addr
|
myVpnAddrs []netip.Addr
|
||||||
myVpnAddrsTable *bart.Lite
|
myVpnAddrsTable *bart.Table[struct{}]
|
||||||
myVpnBroadcastAddrsTable *bart.Lite
|
myVpnBroadcastAddrsTable *bart.Table[struct{}]
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
|
func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
|
||||||
@@ -173,7 +173,6 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
|||||||
|
|
||||||
p.cs.Store(newState)
|
p.cs.Store(newState)
|
||||||
|
|
||||||
//TODO: CERT-V2 newState needs a stringer that does json
|
|
||||||
if initial {
|
if initial {
|
||||||
p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
|
p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
|
||||||
} else {
|
} else {
|
||||||
@@ -194,7 +193,7 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (cs *CertState) GetDefaultCertificate() cert.Certificate {
|
func (cs *CertState) GetDefaultCertificate() cert.Certificate {
|
||||||
c := cs.getCertificate(cs.initiatingVersion)
|
c := cs.getCertificate(cs.defaultVersion)
|
||||||
if c == nil {
|
if c == nil {
|
||||||
panic("No default certificate found")
|
panic("No default certificate found")
|
||||||
}
|
}
|
||||||
@@ -317,37 +316,37 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
|
|||||||
return nil, errors.New("no certificates found in pki.cert")
|
return nil, errors.New("no certificates found in pki.cert")
|
||||||
}
|
}
|
||||||
|
|
||||||
useInitiatingVersion := uint32(1)
|
useDefaultVersion := uint32(1)
|
||||||
if v1 == nil {
|
if v1 == nil {
|
||||||
// The only condition that requires v2 as the default is if only a v2 certificate is present
|
// The only condition that requires v2 as the default is if only a v2 certificate is present
|
||||||
// We do this to avoid having to configure it specifically in the config file
|
// We do this to avoid having to configure it specifically in the config file
|
||||||
useInitiatingVersion = 2
|
useDefaultVersion = 2
|
||||||
}
|
}
|
||||||
|
|
||||||
rawInitiatingVersion := c.GetUint32("pki.initiating_version", useInitiatingVersion)
|
rawDefaultVersion := c.GetUint32("pki.default_version", useDefaultVersion)
|
||||||
var initiatingVersion cert.Version
|
var defaultVersion cert.Version
|
||||||
switch rawInitiatingVersion {
|
switch rawDefaultVersion {
|
||||||
case 1:
|
case 1:
|
||||||
if v1 == nil {
|
if v1 == nil {
|
||||||
return nil, fmt.Errorf("can not use pki.initiating_version 1 without a v1 certificate in pki.cert")
|
return nil, fmt.Errorf("can not use pki.default_version 1 without a v1 certificate in pki.cert")
|
||||||
}
|
}
|
||||||
initiatingVersion = cert.Version1
|
defaultVersion = cert.Version1
|
||||||
case 2:
|
case 2:
|
||||||
initiatingVersion = cert.Version2
|
defaultVersion = cert.Version2
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unknown pki.initiating_version: %v", rawInitiatingVersion)
|
return nil, fmt.Errorf("unknown pki.default_version: %v", rawDefaultVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey)
|
return newCertState(defaultVersion, v1, v2, isPkcs11, curve, rawKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) {
|
func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) {
|
||||||
cs := CertState{
|
cs := CertState{
|
||||||
privateKey: privateKey,
|
privateKey: privateKey,
|
||||||
pkcs11Backed: pkcs11backed,
|
pkcs11Backed: pkcs11backed,
|
||||||
myVpnNetworksTable: new(bart.Lite),
|
myVpnNetworksTable: new(bart.Table[struct{}]),
|
||||||
myVpnAddrsTable: new(bart.Lite),
|
myVpnAddrsTable: new(bart.Table[struct{}]),
|
||||||
myVpnBroadcastAddrsTable: new(bart.Lite),
|
myVpnBroadcastAddrsTable: new(bart.Table[struct{}]),
|
||||||
}
|
}
|
||||||
|
|
||||||
if v1 != nil && v2 != nil {
|
if v1 != nil && v2 != nil {
|
||||||
@@ -361,7 +360,7 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
|
|||||||
|
|
||||||
//TODO: CERT-V2 make sure v2 has v1s address
|
//TODO: CERT-V2 make sure v2 has v1s address
|
||||||
|
|
||||||
cs.initiatingVersion = dv
|
cs.defaultVersion = dv
|
||||||
}
|
}
|
||||||
|
|
||||||
if v1 != nil {
|
if v1 != nil {
|
||||||
@@ -380,8 +379,8 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
|
|||||||
cs.v1Cert = v1
|
cs.v1Cert = v1
|
||||||
cs.v1HandshakeBytes = v1hs
|
cs.v1HandshakeBytes = v1hs
|
||||||
|
|
||||||
if cs.initiatingVersion == 0 {
|
if cs.defaultVersion == 0 {
|
||||||
cs.initiatingVersion = cert.Version1
|
cs.defaultVersion = cert.Version1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -401,8 +400,8 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
|
|||||||
cs.v2Cert = v2
|
cs.v2Cert = v2
|
||||||
cs.v2HandshakeBytes = v2hs
|
cs.v2HandshakeBytes = v2hs
|
||||||
|
|
||||||
if cs.initiatingVersion == 0 {
|
if cs.defaultVersion == 0 {
|
||||||
cs.initiatingVersion = cert.Version2
|
cs.defaultVersion = cert.Version2
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -415,16 +414,16 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
|
|||||||
|
|
||||||
for _, network := range crt.Networks() {
|
for _, network := range crt.Networks() {
|
||||||
cs.myVpnNetworks = append(cs.myVpnNetworks, network)
|
cs.myVpnNetworks = append(cs.myVpnNetworks, network)
|
||||||
cs.myVpnNetworksTable.Insert(network)
|
cs.myVpnNetworksTable.Insert(network, struct{}{})
|
||||||
|
|
||||||
cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr())
|
cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr())
|
||||||
cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()))
|
cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{})
|
||||||
|
|
||||||
if network.Addr().Is4() {
|
if network.Addr().Is4() {
|
||||||
addr := network.Masked().Addr().As4()
|
addr := network.Masked().Addr().As4()
|
||||||
mask := net.CIDRMask(network.Bits(), network.Addr().BitLen())
|
mask := net.CIDRMask(network.Bits(), network.Addr().BitLen())
|
||||||
binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask))
|
binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask))
|
||||||
cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()))
|
cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()), struct{}{})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewPunchyFromConfig(t *testing.T) {
|
func TestNewPunchyFromConfig(t *testing.T) {
|
||||||
@@ -16,39 +15,39 @@ func TestNewPunchyFromConfig(t *testing.T) {
|
|||||||
|
|
||||||
// Test defaults
|
// Test defaults
|
||||||
p := NewPunchyFromConfig(l, c)
|
p := NewPunchyFromConfig(l, c)
|
||||||
assert.False(t, p.GetPunch())
|
assert.Equal(t, false, p.GetPunch())
|
||||||
assert.False(t, p.GetRespond())
|
assert.Equal(t, false, p.GetRespond())
|
||||||
assert.Equal(t, time.Second, p.GetDelay())
|
assert.Equal(t, time.Second, p.GetDelay())
|
||||||
assert.Equal(t, 5*time.Second, p.GetRespondDelay())
|
assert.Equal(t, 5*time.Second, p.GetRespondDelay())
|
||||||
|
|
||||||
// punchy deprecation
|
// punchy deprecation
|
||||||
c.Settings["punchy"] = true
|
c.Settings["punchy"] = true
|
||||||
p = NewPunchyFromConfig(l, c)
|
p = NewPunchyFromConfig(l, c)
|
||||||
assert.True(t, p.GetPunch())
|
assert.Equal(t, true, p.GetPunch())
|
||||||
|
|
||||||
// punchy.punch
|
// punchy.punch
|
||||||
c.Settings["punchy"] = map[string]any{"punch": true}
|
c.Settings["punchy"] = map[interface{}]interface{}{"punch": true}
|
||||||
p = NewPunchyFromConfig(l, c)
|
p = NewPunchyFromConfig(l, c)
|
||||||
assert.True(t, p.GetPunch())
|
assert.Equal(t, true, p.GetPunch())
|
||||||
|
|
||||||
// punch_back deprecation
|
// punch_back deprecation
|
||||||
c.Settings["punch_back"] = true
|
c.Settings["punch_back"] = true
|
||||||
p = NewPunchyFromConfig(l, c)
|
p = NewPunchyFromConfig(l, c)
|
||||||
assert.True(t, p.GetRespond())
|
assert.Equal(t, true, p.GetRespond())
|
||||||
|
|
||||||
// punchy.respond
|
// punchy.respond
|
||||||
c.Settings["punchy"] = map[string]any{"respond": true}
|
c.Settings["punchy"] = map[interface{}]interface{}{"respond": true}
|
||||||
c.Settings["punch_back"] = false
|
c.Settings["punch_back"] = false
|
||||||
p = NewPunchyFromConfig(l, c)
|
p = NewPunchyFromConfig(l, c)
|
||||||
assert.True(t, p.GetRespond())
|
assert.Equal(t, true, p.GetRespond())
|
||||||
|
|
||||||
// punchy.delay
|
// punchy.delay
|
||||||
c.Settings["punchy"] = map[string]any{"delay": "1m"}
|
c.Settings["punchy"] = map[interface{}]interface{}{"delay": "1m"}
|
||||||
p = NewPunchyFromConfig(l, c)
|
p = NewPunchyFromConfig(l, c)
|
||||||
assert.Equal(t, time.Minute, p.GetDelay())
|
assert.Equal(t, time.Minute, p.GetDelay())
|
||||||
|
|
||||||
// punchy.respond_delay
|
// punchy.respond_delay
|
||||||
c.Settings["punchy"] = map[string]any{"respond_delay": "1m"}
|
c.Settings["punchy"] = map[interface{}]interface{}{"respond_delay": "1m"}
|
||||||
p = NewPunchyFromConfig(l, c)
|
p = NewPunchyFromConfig(l, c)
|
||||||
assert.Equal(t, time.Minute, p.GetRespondDelay())
|
assert.Equal(t, time.Minute, p.GetRespondDelay())
|
||||||
}
|
}
|
||||||
@@ -57,22 +56,22 @@ func TestPunchy_reload(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
delay, _ := time.ParseDuration("1m")
|
delay, _ := time.ParseDuration("1m")
|
||||||
require.NoError(t, c.LoadString(`
|
assert.NoError(t, c.LoadString(`
|
||||||
punchy:
|
punchy:
|
||||||
delay: 1m
|
delay: 1m
|
||||||
respond: false
|
respond: false
|
||||||
`))
|
`))
|
||||||
p := NewPunchyFromConfig(l, c)
|
p := NewPunchyFromConfig(l, c)
|
||||||
assert.Equal(t, delay, p.GetDelay())
|
assert.Equal(t, delay, p.GetDelay())
|
||||||
assert.False(t, p.GetRespond())
|
assert.Equal(t, false, p.GetRespond())
|
||||||
|
|
||||||
newDelay, _ := time.ParseDuration("10m")
|
newDelay, _ := time.ParseDuration("10m")
|
||||||
require.NoError(t, c.ReloadConfigString(`
|
assert.NoError(t, c.ReloadConfigString(`
|
||||||
punchy:
|
punchy:
|
||||||
delay: 10m
|
delay: 10m
|
||||||
respond: true
|
respond: true
|
||||||
`))
|
`))
|
||||||
p.reload(c, false)
|
p.reload(c, false)
|
||||||
assert.Equal(t, newDelay, p.GetDelay())
|
assert.Equal(t, newDelay, p.GetDelay())
|
||||||
assert.True(t, p.GetRespond())
|
assert.Equal(t, true, p.GetRespond())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -241,13 +241,15 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
|||||||
logMsg.Info("handleCreateRelayRequest")
|
logMsg.Info("handleCreateRelayRequest")
|
||||||
// Is the source of the relay me? This should never happen, but did happen due to
|
// Is the source of the relay me? This should never happen, but did happen due to
|
||||||
// an issue migrating relays over to newly re-handshaked host info objects.
|
// an issue migrating relays over to newly re-handshaked host info objects.
|
||||||
if f.myVpnAddrsTable.Contains(from) {
|
_, found := f.myVpnAddrsTable.Lookup(from)
|
||||||
|
if found {
|
||||||
logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
|
logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Is the target of the relay me?
|
// Is the target of the relay me?
|
||||||
if f.myVpnAddrsTable.Contains(target) {
|
_, found = f.myVpnAddrsTable.Lookup(target)
|
||||||
|
if found {
|
||||||
existingRelay, ok := h.relayState.QueryRelayForByIp(from)
|
existingRelay, ok := h.relayState.QueryRelayForByIp(from)
|
||||||
if ok {
|
if ok {
|
||||||
switch existingRelay.State {
|
switch existingRelay.State {
|
||||||
|
|||||||
@@ -1,39 +0,0 @@
|
|||||||
package routing
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/slackhq/nebula/firewall"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Hashes the packet source and destination port and always returns a positive integer
|
|
||||||
// Based on 'Prospecting for Hash Functions'
|
|
||||||
// - https://nullprogram.com/blog/2018/07/31/
|
|
||||||
// - https://github.com/skeeto/hash-prospector
|
|
||||||
// [16 21f0aaad 15 d35a2d97 15] = 0.10760229515479501
|
|
||||||
func hashPacket(p *firewall.Packet) int {
|
|
||||||
x := (uint32(p.LocalPort) << 16) | uint32(p.RemotePort)
|
|
||||||
x ^= x >> 16
|
|
||||||
x *= 0x21f0aaad
|
|
||||||
x ^= x >> 15
|
|
||||||
x *= 0xd35a2d97
|
|
||||||
x ^= x >> 15
|
|
||||||
|
|
||||||
return int(x) & 0x7FFFFFFF
|
|
||||||
}
|
|
||||||
|
|
||||||
// For this function to work correctly it requires that the buckets for the gateways have been calculated
|
|
||||||
// If the contract is violated balancing will not work properly and the second return value will return false
|
|
||||||
func BalancePacket(fwPacket *firewall.Packet, gateways []Gateway) (netip.Addr, bool) {
|
|
||||||
hash := hashPacket(fwPacket)
|
|
||||||
|
|
||||||
for i := range gateways {
|
|
||||||
if hash <= gateways[i].BucketUpperBound() {
|
|
||||||
return gateways[i].Addr(), true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If you land here then the buckets for the gateways are not properly calculated
|
|
||||||
// Fallback to random routing and let the caller know
|
|
||||||
return gateways[hash%len(gateways)].Addr(), false
|
|
||||||
}
|
|
||||||
@@ -1,144 +0,0 @@
|
|||||||
package routing
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/slackhq/nebula/firewall"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestPacketsAreBalancedEqually(t *testing.T) {
|
|
||||||
|
|
||||||
gateways := []Gateway{}
|
|
||||||
|
|
||||||
gw1Addr := netip.MustParseAddr("1.0.0.1")
|
|
||||||
gw2Addr := netip.MustParseAddr("1.0.0.2")
|
|
||||||
gw3Addr := netip.MustParseAddr("1.0.0.3")
|
|
||||||
|
|
||||||
gateways = append(gateways, NewGateway(gw1Addr, 1))
|
|
||||||
gateways = append(gateways, NewGateway(gw2Addr, 1))
|
|
||||||
gateways = append(gateways, NewGateway(gw3Addr, 1))
|
|
||||||
|
|
||||||
CalculateBucketsForGateways(gateways)
|
|
||||||
|
|
||||||
gw1count := 0
|
|
||||||
gw2count := 0
|
|
||||||
gw3count := 0
|
|
||||||
|
|
||||||
iterationCount := uint16(65535)
|
|
||||||
for i := uint16(0); i < iterationCount; i++ {
|
|
||||||
packet := firewall.Packet{
|
|
||||||
LocalAddr: netip.MustParseAddr("192.168.1.1"),
|
|
||||||
RemoteAddr: netip.MustParseAddr("10.0.0.1"),
|
|
||||||
LocalPort: i,
|
|
||||||
RemotePort: 65535 - i,
|
|
||||||
Protocol: 6, // TCP
|
|
||||||
Fragment: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
selectedGw, ok := BalancePacket(&packet, gateways)
|
|
||||||
assert.True(t, ok)
|
|
||||||
|
|
||||||
switch selectedGw {
|
|
||||||
case gw1Addr:
|
|
||||||
gw1count += 1
|
|
||||||
case gw2Addr:
|
|
||||||
gw2count += 1
|
|
||||||
case gw3Addr:
|
|
||||||
gw3count += 1
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assert packets are balanced, allow variation of up to 100 packets per gateway
|
|
||||||
assert.InDeltaf(t, iterationCount/3, gw1count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count)
|
|
||||||
assert.InDeltaf(t, iterationCount/3, gw2count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count)
|
|
||||||
assert.InDeltaf(t, iterationCount/3, gw3count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPacketsAreBalancedByPriority(t *testing.T) {
|
|
||||||
|
|
||||||
gateways := []Gateway{}
|
|
||||||
|
|
||||||
gw1Addr := netip.MustParseAddr("1.0.0.1")
|
|
||||||
gw2Addr := netip.MustParseAddr("1.0.0.2")
|
|
||||||
|
|
||||||
gateways = append(gateways, NewGateway(gw1Addr, 10))
|
|
||||||
gateways = append(gateways, NewGateway(gw2Addr, 5))
|
|
||||||
|
|
||||||
CalculateBucketsForGateways(gateways)
|
|
||||||
|
|
||||||
gw1count := 0
|
|
||||||
gw2count := 0
|
|
||||||
|
|
||||||
iterationCount := uint16(65535)
|
|
||||||
for i := uint16(0); i < iterationCount; i++ {
|
|
||||||
packet := firewall.Packet{
|
|
||||||
LocalAddr: netip.MustParseAddr("192.168.1.1"),
|
|
||||||
RemoteAddr: netip.MustParseAddr("10.0.0.1"),
|
|
||||||
LocalPort: i,
|
|
||||||
RemotePort: 65535 - i,
|
|
||||||
Protocol: 6, // TCP
|
|
||||||
Fragment: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
selectedGw, ok := BalancePacket(&packet, gateways)
|
|
||||||
assert.True(t, ok)
|
|
||||||
|
|
||||||
switch selectedGw {
|
|
||||||
case gw1Addr:
|
|
||||||
gw1count += 1
|
|
||||||
case gw2Addr:
|
|
||||||
gw2count += 1
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
iterationCountAsFloat := float32(iterationCount)
|
|
||||||
|
|
||||||
assert.InDeltaf(t, iterationCountAsFloat*(2.0/3.0), gw1count, 100, "Expected %d +/- 100, but got %d", iterationCountAsFloat*(2.0/3.0), gw1count)
|
|
||||||
assert.InDeltaf(t, iterationCountAsFloat*(1.0/3.0), gw2count, 100, "Expected %d +/- 100, but got %d", iterationCountAsFloat*(1.0/3.0), gw2count)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBalancePacketDistributsRandomlyAndReturnsFalseIfBucketsNotCalculated(t *testing.T) {
|
|
||||||
gateways := []Gateway{}
|
|
||||||
|
|
||||||
gw1Addr := netip.MustParseAddr("1.0.0.1")
|
|
||||||
gw2Addr := netip.MustParseAddr("1.0.0.2")
|
|
||||||
|
|
||||||
gateways = append(gateways, NewGateway(gw1Addr, 10))
|
|
||||||
gateways = append(gateways, NewGateway(gw2Addr, 5))
|
|
||||||
|
|
||||||
iterationCount := uint16(65535)
|
|
||||||
gw1count := 0
|
|
||||||
gw2count := 0
|
|
||||||
|
|
||||||
for i := uint16(0); i < iterationCount; i++ {
|
|
||||||
packet := firewall.Packet{
|
|
||||||
LocalAddr: netip.MustParseAddr("192.168.1.1"),
|
|
||||||
RemoteAddr: netip.MustParseAddr("10.0.0.1"),
|
|
||||||
LocalPort: i,
|
|
||||||
RemotePort: 65535 - i,
|
|
||||||
Protocol: 6, // TCP
|
|
||||||
Fragment: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
selectedGw, ok := BalancePacket(&packet, gateways)
|
|
||||||
assert.False(t, ok)
|
|
||||||
|
|
||||||
switch selectedGw {
|
|
||||||
case gw1Addr:
|
|
||||||
gw1count += 1
|
|
||||||
case gw2Addr:
|
|
||||||
gw2count += 1
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, int(iterationCount), (gw1count + gw2count))
|
|
||||||
assert.NotEqual(t, 0, gw1count)
|
|
||||||
assert.NotEqual(t, 0, gw2count)
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,70 +0,0 @@
|
|||||||
package routing
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// Sentinal value
|
|
||||||
BucketNotCalculated = -1
|
|
||||||
)
|
|
||||||
|
|
||||||
type Gateways []Gateway
|
|
||||||
|
|
||||||
func (g Gateways) String() string {
|
|
||||||
str := ""
|
|
||||||
for i, gw := range g {
|
|
||||||
str += gw.String()
|
|
||||||
if i < len(g)-1 {
|
|
||||||
str += ", "
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return str
|
|
||||||
}
|
|
||||||
|
|
||||||
type Gateway struct {
|
|
||||||
addr netip.Addr
|
|
||||||
weight int
|
|
||||||
bucketUpperBound int
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewGateway(addr netip.Addr, weight int) Gateway {
|
|
||||||
return Gateway{addr: addr, weight: weight, bucketUpperBound: BucketNotCalculated}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *Gateway) BucketUpperBound() int {
|
|
||||||
return g.bucketUpperBound
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *Gateway) Addr() netip.Addr {
|
|
||||||
return g.addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *Gateway) String() string {
|
|
||||||
return fmt.Sprintf("{addr: %s, weight: %d}", g.addr, g.weight)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Divide and round to nearest integer
|
|
||||||
func divideAndRound(v uint64, d uint64) uint64 {
|
|
||||||
var tmp uint64 = v + d/2
|
|
||||||
return tmp / d
|
|
||||||
}
|
|
||||||
|
|
||||||
// Implements Hash-Threshold mapping, equivalent to the implementation in the linux kernel.
|
|
||||||
// After this function returns each gateway will have a
|
|
||||||
// positive bucketUpperBound with a maximum value of 2147483647 (INT_MAX)
|
|
||||||
func CalculateBucketsForGateways(gateways []Gateway) {
|
|
||||||
|
|
||||||
var totalWeight int = 0
|
|
||||||
for i := range gateways {
|
|
||||||
totalWeight += gateways[i].weight
|
|
||||||
}
|
|
||||||
|
|
||||||
var loopWeight int = 0
|
|
||||||
for i := range gateways {
|
|
||||||
loopWeight += gateways[i].weight
|
|
||||||
gateways[i].bucketUpperBound = int(divideAndRound(uint64(loopWeight)<<31, uint64(totalWeight))) - 1
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,34 +0,0 @@
|
|||||||
package routing
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestRebalance3_2Split(t *testing.T) {
|
|
||||||
gateways := []Gateway{}
|
|
||||||
|
|
||||||
gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 10})
|
|
||||||
gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 5})
|
|
||||||
|
|
||||||
CalculateBucketsForGateways(gateways)
|
|
||||||
|
|
||||||
assert.Equal(t, 1431655764, gateways[0].bucketUpperBound) // INT_MAX/3*2
|
|
||||||
assert.Equal(t, 2147483647, gateways[1].bucketUpperBound) // INT_MAX
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRebalanceEqualSplit(t *testing.T) {
|
|
||||||
gateways := []Gateway{}
|
|
||||||
|
|
||||||
gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1})
|
|
||||||
gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1})
|
|
||||||
gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1})
|
|
||||||
|
|
||||||
CalculateBucketsForGateways(gateways)
|
|
||||||
|
|
||||||
assert.Equal(t, 715827882, gateways[0].bucketUpperBound) // INT_MAX/3
|
|
||||||
assert.Equal(t, 1431655764, gateways[1].bucketUpperBound) // INT_MAX/3*2
|
|
||||||
assert.Equal(t, 2147483647, gateways[2].bucketUpperBound) // INT_MAX
|
|
||||||
}
|
|
||||||
@@ -13,10 +13,10 @@ import (
|
|||||||
"github.com/slackhq/nebula/cert_test"
|
"github.com/slackhq/nebula/cert_test"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
type m = map[string]any
|
type m map[string]interface{}
|
||||||
|
|
||||||
func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service {
|
func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service {
|
||||||
_, _, myPrivKey, myPEM := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{})
|
_, _, myPrivKey, myPEM := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{})
|
||||||
|
|||||||
92
ssh.go
92
ssh.go
@@ -124,10 +124,10 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
rawKeys := c.Get("sshd.authorized_users")
|
rawKeys := c.Get("sshd.authorized_users")
|
||||||
keys, ok := rawKeys.([]any)
|
keys, ok := rawKeys.([]interface{})
|
||||||
if ok {
|
if ok {
|
||||||
for _, rk := range keys {
|
for _, rk := range keys {
|
||||||
kDef, ok := rk.(map[string]any)
|
kDef, ok := rk.(map[interface{}]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
l.WithField("sshKeyConfig", rk).Warn("Authorized user had an error, ignoring")
|
l.WithField("sshKeyConfig", rk).Warn("Authorized user had an error, ignoring")
|
||||||
continue
|
continue
|
||||||
@@ -148,7 +148,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
case []any:
|
case []interface{}:
|
||||||
for _, subK := range v {
|
for _, subK := range v {
|
||||||
sk, ok := subK.(string)
|
sk, ok := subK.(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -190,7 +190,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
|||||||
ssh.RegisterCommand(&sshd.Command{
|
ssh.RegisterCommand(&sshd.Command{
|
||||||
Name: "list-hostmap",
|
Name: "list-hostmap",
|
||||||
ShortDescription: "List all known previously connected hosts",
|
ShortDescription: "List all known previously connected hosts",
|
||||||
Flags: func() (*flag.FlagSet, any) {
|
Flags: func() (*flag.FlagSet, interface{}) {
|
||||||
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
||||||
s := sshListHostMapFlags{}
|
s := sshListHostMapFlags{}
|
||||||
fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
|
fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
|
||||||
@@ -198,7 +198,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
|||||||
fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table")
|
fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table")
|
||||||
return fl, &s
|
return fl, &s
|
||||||
},
|
},
|
||||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
return sshListHostMap(f.hostMap, fs, w)
|
return sshListHostMap(f.hostMap, fs, w)
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -206,7 +206,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
|||||||
ssh.RegisterCommand(&sshd.Command{
|
ssh.RegisterCommand(&sshd.Command{
|
||||||
Name: "list-pending-hostmap",
|
Name: "list-pending-hostmap",
|
||||||
ShortDescription: "List all handshaking hosts",
|
ShortDescription: "List all handshaking hosts",
|
||||||
Flags: func() (*flag.FlagSet, any) {
|
Flags: func() (*flag.FlagSet, interface{}) {
|
||||||
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
||||||
s := sshListHostMapFlags{}
|
s := sshListHostMapFlags{}
|
||||||
fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
|
fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
|
||||||
@@ -214,7 +214,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
|||||||
fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table")
|
fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table")
|
||||||
return fl, &s
|
return fl, &s
|
||||||
},
|
},
|
||||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
return sshListHostMap(f.handshakeManager, fs, w)
|
return sshListHostMap(f.handshakeManager, fs, w)
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -222,14 +222,14 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
|||||||
ssh.RegisterCommand(&sshd.Command{
|
ssh.RegisterCommand(&sshd.Command{
|
||||||
Name: "list-lighthouse-addrmap",
|
Name: "list-lighthouse-addrmap",
|
||||||
ShortDescription: "List all lighthouse map entries",
|
ShortDescription: "List all lighthouse map entries",
|
||||||
Flags: func() (*flag.FlagSet, any) {
|
Flags: func() (*flag.FlagSet, interface{}) {
|
||||||
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
||||||
s := sshListHostMapFlags{}
|
s := sshListHostMapFlags{}
|
||||||
fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
|
fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
|
||||||
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
|
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
|
||||||
return fl, &s
|
return fl, &s
|
||||||
},
|
},
|
||||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
return sshListLighthouseMap(f.lightHouse, fs, w)
|
return sshListLighthouseMap(f.lightHouse, fs, w)
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -237,7 +237,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
|||||||
ssh.RegisterCommand(&sshd.Command{
|
ssh.RegisterCommand(&sshd.Command{
|
||||||
Name: "reload",
|
Name: "reload",
|
||||||
ShortDescription: "Reloads configuration from disk, same as sending HUP to the process",
|
ShortDescription: "Reloads configuration from disk, same as sending HUP to the process",
|
||||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
return sshReload(c, w)
|
return sshReload(c, w)
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -251,7 +251,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
|||||||
ssh.RegisterCommand(&sshd.Command{
|
ssh.RegisterCommand(&sshd.Command{
|
||||||
Name: "stop-cpu-profile",
|
Name: "stop-cpu-profile",
|
||||||
ShortDescription: "Stops a cpu profile and writes output to the previously provided file",
|
ShortDescription: "Stops a cpu profile and writes output to the previously provided file",
|
||||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
pprof.StopCPUProfile()
|
pprof.StopCPUProfile()
|
||||||
return w.WriteLine("If a CPU profile was running it is now stopped")
|
return w.WriteLine("If a CPU profile was running it is now stopped")
|
||||||
},
|
},
|
||||||
@@ -278,7 +278,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
|||||||
ssh.RegisterCommand(&sshd.Command{
|
ssh.RegisterCommand(&sshd.Command{
|
||||||
Name: "log-level",
|
Name: "log-level",
|
||||||
ShortDescription: "Gets or sets the current log level",
|
ShortDescription: "Gets or sets the current log level",
|
||||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
return sshLogLevel(l, fs, a, w)
|
return sshLogLevel(l, fs, a, w)
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -286,7 +286,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
|||||||
ssh.RegisterCommand(&sshd.Command{
|
ssh.RegisterCommand(&sshd.Command{
|
||||||
Name: "log-format",
|
Name: "log-format",
|
||||||
ShortDescription: "Gets or sets the current log format",
|
ShortDescription: "Gets or sets the current log format",
|
||||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
return sshLogFormat(l, fs, a, w)
|
return sshLogFormat(l, fs, a, w)
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -294,7 +294,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
|||||||
ssh.RegisterCommand(&sshd.Command{
|
ssh.RegisterCommand(&sshd.Command{
|
||||||
Name: "version",
|
Name: "version",
|
||||||
ShortDescription: "Prints the currently running version of nebula",
|
ShortDescription: "Prints the currently running version of nebula",
|
||||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
return sshVersion(f, fs, a, w)
|
return sshVersion(f, fs, a, w)
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -302,14 +302,14 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
|||||||
ssh.RegisterCommand(&sshd.Command{
|
ssh.RegisterCommand(&sshd.Command{
|
||||||
Name: "device-info",
|
Name: "device-info",
|
||||||
ShortDescription: "Prints information about the network device.",
|
ShortDescription: "Prints information about the network device.",
|
||||||
Flags: func() (*flag.FlagSet, any) {
|
Flags: func() (*flag.FlagSet, interface{}) {
|
||||||
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
||||||
s := sshDeviceInfoFlags{}
|
s := sshDeviceInfoFlags{}
|
||||||
fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
|
fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
|
||||||
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
|
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
|
||||||
return fl, &s
|
return fl, &s
|
||||||
},
|
},
|
||||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
return sshDeviceInfo(f, fs, w)
|
return sshDeviceInfo(f, fs, w)
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -317,7 +317,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
|||||||
ssh.RegisterCommand(&sshd.Command{
|
ssh.RegisterCommand(&sshd.Command{
|
||||||
Name: "print-cert",
|
Name: "print-cert",
|
||||||
ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn addr",
|
ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn addr",
|
||||||
Flags: func() (*flag.FlagSet, any) {
|
Flags: func() (*flag.FlagSet, interface{}) {
|
||||||
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
||||||
s := sshPrintCertFlags{}
|
s := sshPrintCertFlags{}
|
||||||
fl.BoolVar(&s.Json, "json", false, "outputs as json")
|
fl.BoolVar(&s.Json, "json", false, "outputs as json")
|
||||||
@@ -325,7 +325,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
|||||||
fl.BoolVar(&s.Raw, "raw", false, "raw prints the PEM encoded certificate, not compatible with -json or -pretty")
|
fl.BoolVar(&s.Raw, "raw", false, "raw prints the PEM encoded certificate, not compatible with -json or -pretty")
|
||||||
return fl, &s
|
return fl, &s
|
||||||
},
|
},
|
||||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
return sshPrintCert(f, fs, a, w)
|
return sshPrintCert(f, fs, a, w)
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -333,13 +333,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
|||||||
ssh.RegisterCommand(&sshd.Command{
|
ssh.RegisterCommand(&sshd.Command{
|
||||||
Name: "print-tunnel",
|
Name: "print-tunnel",
|
||||||
ShortDescription: "Prints json details about a tunnel for the provided vpn addr",
|
ShortDescription: "Prints json details about a tunnel for the provided vpn addr",
|
||||||
Flags: func() (*flag.FlagSet, any) {
|
Flags: func() (*flag.FlagSet, interface{}) {
|
||||||
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
||||||
s := sshPrintTunnelFlags{}
|
s := sshPrintTunnelFlags{}
|
||||||
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json")
|
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json")
|
||||||
return fl, &s
|
return fl, &s
|
||||||
},
|
},
|
||||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
return sshPrintTunnel(f, fs, a, w)
|
return sshPrintTunnel(f, fs, a, w)
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -347,13 +347,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
|||||||
ssh.RegisterCommand(&sshd.Command{
|
ssh.RegisterCommand(&sshd.Command{
|
||||||
Name: "print-relays",
|
Name: "print-relays",
|
||||||
ShortDescription: "Prints json details about all relay info",
|
ShortDescription: "Prints json details about all relay info",
|
||||||
Flags: func() (*flag.FlagSet, any) {
|
Flags: func() (*flag.FlagSet, interface{}) {
|
||||||
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
||||||
s := sshPrintTunnelFlags{}
|
s := sshPrintTunnelFlags{}
|
||||||
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json")
|
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json")
|
||||||
return fl, &s
|
return fl, &s
|
||||||
},
|
},
|
||||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
return sshPrintRelays(f, fs, a, w)
|
return sshPrintRelays(f, fs, a, w)
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -361,13 +361,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
|||||||
ssh.RegisterCommand(&sshd.Command{
|
ssh.RegisterCommand(&sshd.Command{
|
||||||
Name: "change-remote",
|
Name: "change-remote",
|
||||||
ShortDescription: "Changes the remote address used in the tunnel for the provided vpn addr",
|
ShortDescription: "Changes the remote address used in the tunnel for the provided vpn addr",
|
||||||
Flags: func() (*flag.FlagSet, any) {
|
Flags: func() (*flag.FlagSet, interface{}) {
|
||||||
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
||||||
s := sshChangeRemoteFlags{}
|
s := sshChangeRemoteFlags{}
|
||||||
fl.StringVar(&s.Address, "address", "", "The new remote address, ip:port")
|
fl.StringVar(&s.Address, "address", "", "The new remote address, ip:port")
|
||||||
return fl, &s
|
return fl, &s
|
||||||
},
|
},
|
||||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
return sshChangeRemote(f, fs, a, w)
|
return sshChangeRemote(f, fs, a, w)
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -375,13 +375,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
|||||||
ssh.RegisterCommand(&sshd.Command{
|
ssh.RegisterCommand(&sshd.Command{
|
||||||
Name: "close-tunnel",
|
Name: "close-tunnel",
|
||||||
ShortDescription: "Closes a tunnel for the provided vpn addr",
|
ShortDescription: "Closes a tunnel for the provided vpn addr",
|
||||||
Flags: func() (*flag.FlagSet, any) {
|
Flags: func() (*flag.FlagSet, interface{}) {
|
||||||
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
||||||
s := sshCloseTunnelFlags{}
|
s := sshCloseTunnelFlags{}
|
||||||
fl.BoolVar(&s.LocalOnly, "local-only", false, "Disables notifying the remote that the tunnel is shutting down")
|
fl.BoolVar(&s.LocalOnly, "local-only", false, "Disables notifying the remote that the tunnel is shutting down")
|
||||||
return fl, &s
|
return fl, &s
|
||||||
},
|
},
|
||||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
return sshCloseTunnel(f, fs, a, w)
|
return sshCloseTunnel(f, fs, a, w)
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -390,13 +390,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
|||||||
Name: "create-tunnel",
|
Name: "create-tunnel",
|
||||||
ShortDescription: "Creates a tunnel for the provided vpn address",
|
ShortDescription: "Creates a tunnel for the provided vpn address",
|
||||||
Help: "The lighthouses will be queried for real addresses but you can provide one as well.",
|
Help: "The lighthouses will be queried for real addresses but you can provide one as well.",
|
||||||
Flags: func() (*flag.FlagSet, any) {
|
Flags: func() (*flag.FlagSet, interface{}) {
|
||||||
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
||||||
s := sshCreateTunnelFlags{}
|
s := sshCreateTunnelFlags{}
|
||||||
fl.StringVar(&s.Address, "address", "", "Optionally provide a real remote address, ip:port ")
|
fl.StringVar(&s.Address, "address", "", "Optionally provide a real remote address, ip:port ")
|
||||||
return fl, &s
|
return fl, &s
|
||||||
},
|
},
|
||||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
return sshCreateTunnel(f, fs, a, w)
|
return sshCreateTunnel(f, fs, a, w)
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -405,13 +405,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
|||||||
Name: "query-lighthouse",
|
Name: "query-lighthouse",
|
||||||
ShortDescription: "Query the lighthouses for the provided vpn address",
|
ShortDescription: "Query the lighthouses for the provided vpn address",
|
||||||
Help: "This command is asynchronous. Only currently known udp addresses will be printed.",
|
Help: "This command is asynchronous. Only currently known udp addresses will be printed.",
|
||||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
return sshQueryLighthouse(f, fs, a, w)
|
return sshQueryLighthouse(f, fs, a, w)
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshListHostMap(hl controlHostLister, a any, w sshd.StringWriter) error {
|
func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) error {
|
||||||
fs, ok := a.(*sshListHostMapFlags)
|
fs, ok := a.(*sshListHostMapFlags)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
@@ -451,7 +451,7 @@ func sshListHostMap(hl controlHostLister, a any, w sshd.StringWriter) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshListLighthouseMap(lightHouse *LightHouse, a any, w sshd.StringWriter) error {
|
func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWriter) error {
|
||||||
fs, ok := a.(*sshListHostMapFlags)
|
fs, ok := a.(*sshListHostMapFlags)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
@@ -505,7 +505,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a any, w sshd.StringWriter) er
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshStartCpuProfile(fs any, a []string, w sshd.StringWriter) error {
|
func sshStartCpuProfile(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
if len(a) == 0 {
|
if len(a) == 0 {
|
||||||
err := w.WriteLine("No path to write profile provided")
|
err := w.WriteLine("No path to write profile provided")
|
||||||
return err
|
return err
|
||||||
@@ -527,11 +527,11 @@ func sshStartCpuProfile(fs any, a []string, w sshd.StringWriter) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshVersion(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
|
func sshVersion(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
return w.WriteLine(fmt.Sprintf("%s", ifce.version))
|
return w.WriteLine(fmt.Sprintf("%s", ifce.version))
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshQueryLighthouse(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
|
func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
if len(a) == 0 {
|
if len(a) == 0 {
|
||||||
return w.WriteLine("No vpn address was provided")
|
return w.WriteLine("No vpn address was provided")
|
||||||
}
|
}
|
||||||
@@ -553,7 +553,7 @@ func sshQueryLighthouse(ifce *Interface, fs any, a []string, w sshd.StringWriter
|
|||||||
return json.NewEncoder(w.GetWriter()).Encode(cm)
|
return json.NewEncoder(w.GetWriter()).Encode(cm)
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshCloseTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
|
func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
flags, ok := fs.(*sshCloseTunnelFlags)
|
flags, ok := fs.(*sshCloseTunnelFlags)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
@@ -593,7 +593,7 @@ func sshCloseTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) er
|
|||||||
return w.WriteLine("Closed")
|
return w.WriteLine("Closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshCreateTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
|
func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
flags, ok := fs.(*sshCreateTunnelFlags)
|
flags, ok := fs.(*sshCreateTunnelFlags)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
@@ -638,7 +638,7 @@ func sshCreateTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) e
|
|||||||
return w.WriteLine("Created")
|
return w.WriteLine("Created")
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshChangeRemote(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
|
func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
flags, ok := fs.(*sshChangeRemoteFlags)
|
flags, ok := fs.(*sshChangeRemoteFlags)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
@@ -675,7 +675,7 @@ func sshChangeRemote(ifce *Interface, fs any, a []string, w sshd.StringWriter) e
|
|||||||
return w.WriteLine("Changed")
|
return w.WriteLine("Changed")
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshGetHeapProfile(fs any, a []string, w sshd.StringWriter) error {
|
func sshGetHeapProfile(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
if len(a) == 0 {
|
if len(a) == 0 {
|
||||||
return w.WriteLine("No path to write profile provided")
|
return w.WriteLine("No path to write profile provided")
|
||||||
}
|
}
|
||||||
@@ -696,7 +696,7 @@ func sshGetHeapProfile(fs any, a []string, w sshd.StringWriter) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshMutexProfileFraction(fs any, a []string, w sshd.StringWriter) error {
|
func sshMutexProfileFraction(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
if len(a) == 0 {
|
if len(a) == 0 {
|
||||||
rate := runtime.SetMutexProfileFraction(-1)
|
rate := runtime.SetMutexProfileFraction(-1)
|
||||||
return w.WriteLine(fmt.Sprintf("Current value: %d", rate))
|
return w.WriteLine(fmt.Sprintf("Current value: %d", rate))
|
||||||
@@ -711,7 +711,7 @@ func sshMutexProfileFraction(fs any, a []string, w sshd.StringWriter) error {
|
|||||||
return w.WriteLine(fmt.Sprintf("New value: %d. Old value: %d", newRate, oldRate))
|
return w.WriteLine(fmt.Sprintf("New value: %d. Old value: %d", newRate, oldRate))
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshGetMutexProfile(fs any, a []string, w sshd.StringWriter) error {
|
func sshGetMutexProfile(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
if len(a) == 0 {
|
if len(a) == 0 {
|
||||||
return w.WriteLine("No path to write profile provided")
|
return w.WriteLine("No path to write profile provided")
|
||||||
}
|
}
|
||||||
@@ -735,7 +735,7 @@ func sshGetMutexProfile(fs any, a []string, w sshd.StringWriter) error {
|
|||||||
return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a))
|
return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a))
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshLogLevel(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error {
|
func sshLogLevel(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
if len(a) == 0 {
|
if len(a) == 0 {
|
||||||
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
|
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
|
||||||
}
|
}
|
||||||
@@ -749,7 +749,7 @@ func sshLogLevel(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) erro
|
|||||||
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
|
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshLogFormat(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error {
|
func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
if len(a) == 0 {
|
if len(a) == 0 {
|
||||||
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
|
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
|
||||||
}
|
}
|
||||||
@@ -767,7 +767,7 @@ func sshLogFormat(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) err
|
|||||||
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
|
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
|
func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
args, ok := fs.(*sshPrintCertFlags)
|
args, ok := fs.(*sshPrintCertFlags)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
@@ -822,7 +822,7 @@ func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) erro
|
|||||||
return w.WriteLine(cert.String())
|
return w.WriteLine(cert.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshPrintRelays(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
|
func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
args, ok := fs.(*sshPrintTunnelFlags)
|
args, ok := fs.(*sshPrintTunnelFlags)
|
||||||
if !ok {
|
if !ok {
|
||||||
w.WriteLine(fmt.Sprintf("sshPrintRelays failed to convert args type"))
|
w.WriteLine(fmt.Sprintf("sshPrintRelays failed to convert args type"))
|
||||||
@@ -919,7 +919,7 @@ func sshPrintRelays(ifce *Interface, fs any, a []string, w sshd.StringWriter) er
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshPrintTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
|
func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
|
||||||
args, ok := fs.(*sshPrintTunnelFlags)
|
args, ok := fs.(*sshPrintTunnelFlags)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
@@ -951,7 +951,7 @@ func sshPrintTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) er
|
|||||||
return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.GetPreferredRanges()))
|
return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.GetPreferredRanges()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshDeviceInfo(ifce *Interface, fs any, w sshd.StringWriter) error {
|
func sshDeviceInfo(ifce *Interface, fs interface{}, w sshd.StringWriter) error {
|
||||||
|
|
||||||
data := struct {
|
data := struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
|
|
||||||
// CommandFlags is a function called before help or command execution to parse command line flags
|
// CommandFlags is a function called before help or command execution to parse command line flags
|
||||||
// It should return a flag.FlagSet instance and a pointer to the struct that will contain parsed flags
|
// It should return a flag.FlagSet instance and a pointer to the struct that will contain parsed flags
|
||||||
type CommandFlags func() (*flag.FlagSet, any)
|
type CommandFlags func() (*flag.FlagSet, interface{})
|
||||||
|
|
||||||
// CommandCallback is the function called when your command should execute.
|
// CommandCallback is the function called when your command should execute.
|
||||||
// fs will be a a pointer to the struct provided by Command.Flags callback, if there was one. -h and -help are reserved
|
// fs will be a a pointer to the struct provided by Command.Flags callback, if there was one. -h and -help are reserved
|
||||||
@@ -21,7 +21,7 @@ type CommandFlags func() (*flag.FlagSet, any)
|
|||||||
// w is the writer to use when sending messages back to the client.
|
// w is the writer to use when sending messages back to the client.
|
||||||
// If an error is returned by the callback it is logged locally, the callback should handle messaging errors to the user
|
// If an error is returned by the callback it is logged locally, the callback should handle messaging errors to the user
|
||||||
// where appropriate
|
// where appropriate
|
||||||
type CommandCallback func(fs any, a []string, w StringWriter) error
|
type CommandCallback func(fs interface{}, a []string, w StringWriter) error
|
||||||
|
|
||||||
type Command struct {
|
type Command struct {
|
||||||
Name string
|
Name string
|
||||||
@@ -34,7 +34,7 @@ type Command struct {
|
|||||||
func execCommand(c *Command, args []string, w StringWriter) error {
|
func execCommand(c *Command, args []string, w StringWriter) error {
|
||||||
var (
|
var (
|
||||||
fl *flag.FlagSet
|
fl *flag.FlagSet
|
||||||
fs any
|
fs interface{}
|
||||||
)
|
)
|
||||||
|
|
||||||
if c.Flags != nil {
|
if c.Flags != nil {
|
||||||
@@ -85,7 +85,7 @@ func lookupCommand(c *radix.Tree, sCmd string) (*Command, error) {
|
|||||||
|
|
||||||
func matchCommand(c *radix.Tree, cmd string) []string {
|
func matchCommand(c *radix.Tree, cmd string) []string {
|
||||||
cmds := make([]string, 0)
|
cmds := make([]string, 0)
|
||||||
c.WalkPrefix(cmd, func(found string, v any) bool {
|
c.WalkPrefix(cmd, func(found string, v interface{}) bool {
|
||||||
cmds = append(cmds, found)
|
cmds = append(cmds, found)
|
||||||
return false
|
return false
|
||||||
})
|
})
|
||||||
@@ -95,7 +95,7 @@ func matchCommand(c *radix.Tree, cmd string) []string {
|
|||||||
|
|
||||||
func allCommands(c *radix.Tree) []*Command {
|
func allCommands(c *radix.Tree) []*Command {
|
||||||
cmds := make([]*Command, 0)
|
cmds := make([]*Command, 0)
|
||||||
c.WalkPrefix("", func(found string, v any) bool {
|
c.WalkPrefix("", func(found string, v interface{}) bool {
|
||||||
cmd, ok := v.(*Command)
|
cmd, ok := v.(*Command)
|
||||||
if ok {
|
if ok {
|
||||||
cmds = append(cmds, cmd)
|
cmds = append(cmds, cmd)
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
|
|||||||
s.RegisterCommand(&Command{
|
s.RegisterCommand(&Command{
|
||||||
Name: "help",
|
Name: "help",
|
||||||
ShortDescription: "prints available commands or help <command> for specific usage info",
|
ShortDescription: "prints available commands or help <command> for specific usage info",
|
||||||
Callback: func(a any, args []string, w StringWriter) error {
|
Callback: func(a interface{}, args []string, w StringWriter) error {
|
||||||
return helpCallback(s.commands, args, w)
|
return helpCallback(s.commands, args, w)
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -9,13 +9,13 @@ import (
|
|||||||
"github.com/armon/go-radix"
|
"github.com/armon/go-radix"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
"golang.org/x/term"
|
"golang.org/x/crypto/ssh/terminal"
|
||||||
)
|
)
|
||||||
|
|
||||||
type session struct {
|
type session struct {
|
||||||
l *logrus.Entry
|
l *logrus.Entry
|
||||||
c *ssh.ServerConn
|
c *ssh.ServerConn
|
||||||
term *term.Terminal
|
term *terminal.Terminal
|
||||||
commands *radix.Tree
|
commands *radix.Tree
|
||||||
exitChan chan bool
|
exitChan chan bool
|
||||||
}
|
}
|
||||||
@@ -31,7 +31,7 @@ func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.New
|
|||||||
s.commands.Insert("logout", &Command{
|
s.commands.Insert("logout", &Command{
|
||||||
Name: "logout",
|
Name: "logout",
|
||||||
ShortDescription: "Ends the current session",
|
ShortDescription: "Ends the current session",
|
||||||
Callback: func(a any, args []string, w StringWriter) error {
|
Callback: func(a interface{}, args []string, w StringWriter) error {
|
||||||
s.Close()
|
s.Close()
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
@@ -106,8 +106,8 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) createTerm(channel ssh.Channel) *term.Terminal {
|
func (s *session) createTerm(channel ssh.Channel) *terminal.Terminal {
|
||||||
term := term.NewTerminal(channel, s.c.User()+"@nebula > ")
|
term := terminal.NewTerminal(channel, s.c.User()+"@nebula > ")
|
||||||
term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) {
|
term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) {
|
||||||
// key 9 is tab
|
// key 9 is tab
|
||||||
if key == 9 {
|
if key == 9 {
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
|
|
||||||
// AssertDeepCopyEqual checks to see if two variables have the same values but DO NOT share any memory
|
// AssertDeepCopyEqual checks to see if two variables have the same values but DO NOT share any memory
|
||||||
// There is currently a special case for `time.loc` (as this code traverses into unexported fields)
|
// There is currently a special case for `time.loc` (as this code traverses into unexported fields)
|
||||||
func AssertDeepCopyEqual(t *testing.T, a any, b any) {
|
func AssertDeepCopyEqual(t *testing.T, a interface{}, b interface{}) {
|
||||||
v1 := reflect.ValueOf(a)
|
v1 := reflect.ValueOf(a)
|
||||||
v2 := reflect.ValueOf(b)
|
v2 := reflect.ValueOf(b)
|
||||||
|
|
||||||
|
|||||||
@@ -4,14 +4,12 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/routing"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type NoopTun struct{}
|
type NoopTun struct{}
|
||||||
|
|
||||||
func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways {
|
func (NoopTun) RouteFor(addr netip.Addr) netip.Addr {
|
||||||
return routing.Gateways{}
|
return netip.Addr{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (NoopTun) Activate() error {
|
func (NoopTun) Activate() error {
|
||||||
|
|||||||
@@ -84,10 +84,6 @@ func (u *StdConn) SetSendBuffer(n int) error {
|
|||||||
return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n)
|
return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) SetSoMark(mark int) error {
|
|
||||||
return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_MARK, mark)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) GetRecvBuffer() (int, error) {
|
func (u *StdConn) GetRecvBuffer() (int, error) {
|
||||||
return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF)
|
return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF)
|
||||||
}
|
}
|
||||||
@@ -96,10 +92,6 @@ func (u *StdConn) GetSendBuffer() (int, error) {
|
|||||||
return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF)
|
return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) GetSoMark() (int, error) {
|
|
||||||
return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_MARK)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
||||||
sa, err := unix.Getsockname(u.sysFd)
|
sa, err := unix.Getsockname(u.sysFd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -278,22 +270,6 @@ func (u *StdConn) ReloadConfig(c *config.C) {
|
|||||||
u.l.WithError(err).Error("Failed to set listen.write_buffer")
|
u.l.WithError(err).Error("Failed to set listen.write_buffer")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
b = c.GetInt("listen.so_mark", 0)
|
|
||||||
s, err := u.GetSoMark()
|
|
||||||
if b > 0 || (err == nil && s != 0) {
|
|
||||||
err := u.SetSoMark(b)
|
|
||||||
if err == nil {
|
|
||||||
s, err := u.GetSoMark()
|
|
||||||
if err == nil {
|
|
||||||
u.l.WithField("mark", s).Info("listen.so_mark was set")
|
|
||||||
} else {
|
|
||||||
u.l.WithError(err).Warn("Failed to get listen.so_mark")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
u.l.WithError(err).Error("Failed to set listen.so_mark")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
||||||
|
|||||||
@@ -9,11 +9,11 @@ import (
|
|||||||
|
|
||||||
type ContextualError struct {
|
type ContextualError struct {
|
||||||
RealError error
|
RealError error
|
||||||
Fields map[string]any
|
Fields map[string]interface{}
|
||||||
Context string
|
Context string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewContextualError(msg string, fields map[string]any, realError error) *ContextualError {
|
func NewContextualError(msg string, fields map[string]interface{}, realError error) *ContextualError {
|
||||||
return &ContextualError{Context: msg, Fields: fields, RealError: realError}
|
return &ContextualError{Context: msg, Fields: fields, RealError: realError}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
type m = map[string]any
|
type m map[string]interface{}
|
||||||
|
|
||||||
type TestLogWriter struct {
|
type TestLogWriter struct {
|
||||||
Logs []string
|
Logs []string
|
||||||
|
|||||||
Reference in New Issue
Block a user