diff --git a/.github/workflows/gofmt.yml b/.github/workflows/gofmt.yml index e0d41ae..20a39cf 100644 --- a/.github/workflows/gofmt.yml +++ b/.github/workflows/gofmt.yml @@ -18,7 +18,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.22' + go-version: '1.23' check-latest: true - name: Install goimports diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 31987db..392f71b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -14,7 +14,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.22' + go-version: '1.23' check-latest: true - name: Build @@ -37,7 +37,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.22' + go-version: '1.23' check-latest: true - name: Build @@ -70,7 +70,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.22' + go-version: '1.23' check-latest: true - name: Import certificates diff --git a/.github/workflows/smoke-extra.yml b/.github/workflows/smoke-extra.yml index 2b5e6e9..de582de 100644 --- a/.github/workflows/smoke-extra.yml +++ b/.github/workflows/smoke-extra.yml @@ -27,6 +27,9 @@ jobs: go-version-file: 'go.mod' check-latest: true + - name: add hashicorp source + run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list + - name: install vagrant run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox diff --git a/.github/workflows/smoke.yml b/.github/workflows/smoke.yml index 08b2d3d..b3b847e 100644 --- a/.github/workflows/smoke.yml +++ b/.github/workflows/smoke.yml @@ -22,7 +22,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.22' + go-version: '1.23' check-latest: true - name: build diff --git a/.github/workflows/smoke/smoke-vagrant.sh b/.github/workflows/smoke/smoke-vagrant.sh index 76cf72f..1c1e3c5 100755 --- a/.github/workflows/smoke/smoke-vagrant.sh +++ b/.github/workflows/smoke/smoke-vagrant.sh @@ -29,13 +29,13 @@ docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test docker run --name host2 --rm "$CONTAINER" -config host2.yml -test vagrant up -vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" +vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" -- -T docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & sleep 1 docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & sleep 1 -vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" & +vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" 2>&1 -- -T | tee logs/host3 | sed -u 's/^/ [host3] /' & sleep 15 # grab tcpdump pcaps for debugging @@ -46,8 +46,8 @@ docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host # vagrant ssh -c "tcpdump -i nebula1 -q -w - -U" 2>logs/host3.inside.log >logs/host3.inside.pcap & # vagrant ssh -c "tcpdump -i eth0 -q -w - -U" 2>logs/host3.outside.log >logs/host3.outside.pcap & -docker exec host2 ncat -nklv 0.0.0.0 2000 & -vagrant ssh -c "ncat -nklv 0.0.0.0 2000" & +#docker exec host2 ncat -nklv 0.0.0.0 2000 & +#vagrant ssh -c "ncat -nklv 0.0.0.0 2000" & #docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 & #vagrant ssh -c "ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000" & @@ -68,11 +68,11 @@ docker exec host2 ping -c1 192.168.100.1 # Should fail because not allowed by host3 inbound firewall ! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1 -set +x -echo -echo " *** Testing ncat from host2" -echo -set -x +#set +x +#echo +#echo " *** Testing ncat from host2" +#echo +#set -x # Should fail because not allowed by host3 inbound firewall #! docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1 #! docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1 @@ -82,18 +82,18 @@ echo echo " *** Testing ping from host3" echo set -x -vagrant ssh -c "ping -c1 192.168.100.1" -vagrant ssh -c "ping -c1 192.168.100.2" +vagrant ssh -c "ping -c1 192.168.100.1" -- -T +vagrant ssh -c "ping -c1 192.168.100.2" -- -T -set +x -echo -echo " *** Testing ncat from host3" -echo -set -x +#set +x +#echo +#echo " *** Testing ncat from host3" +#echo +#set -x #vagrant ssh -c "ncat -nzv -w5 192.168.100.2 2000" #vagrant ssh -c "ncat -nzuv -w5 192.168.100.2 3000" | grep -q host2 -vagrant ssh -c "sudo xargs kill math.MaxUint16 { return nil, fmt.Errorf("invalid port: %d", port) @@ -38,32 +42,38 @@ func (c *calculatedRemote) String() string { return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port) } -func (c *calculatedRemote) Apply(ip netip.Addr) *Ip4AndPort { - // Combine the masked bytes of the "mask" IP with the unmasked bytes - // of the overlay IP - if c.ipNet.Addr().Is4() { - return c.apply4(ip) - } - return c.apply6(ip) -} - -func (c *calculatedRemote) apply4(ip netip.Addr) *Ip4AndPort { - //TODO: IPV6-WORK this can be less crappy +func (c *calculatedRemote) ApplyV4(addr netip.Addr) *V4AddrPort { + // Combine the masked bytes of the "mask" IP with the unmasked bytes of the overlay IP maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen()) mask := binary.BigEndian.Uint32(maskb[:]) b := c.mask.Addr().As4() - maskIp := binary.BigEndian.Uint32(b[:]) + maskAddr := binary.BigEndian.Uint32(b[:]) - b = ip.As4() - intIp := binary.BigEndian.Uint32(b[:]) + b = addr.As4() + intAddr := binary.BigEndian.Uint32(b[:]) - return &Ip4AndPort{(maskIp & mask) | (intIp & ^mask), c.port} + return &V4AddrPort{(maskAddr & mask) | (intAddr & ^mask), c.port} } -func (c *calculatedRemote) apply6(ip netip.Addr) *Ip4AndPort { - //TODO: IPV6-WORK - panic("Can not calculate ipv6 remote addresses") +func (c *calculatedRemote) ApplyV6(addr netip.Addr) *V6AddrPort { + mask := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen()) + maskAddr := c.mask.Addr().As16() + calcAddr := addr.As16() + + ap := V6AddrPort{Port: c.port} + + maskb := binary.BigEndian.Uint64(mask[:8]) + maskAddrb := binary.BigEndian.Uint64(maskAddr[:8]) + calcAddrb := binary.BigEndian.Uint64(calcAddr[:8]) + ap.Hi = (maskAddrb & maskb) | (calcAddrb & ^maskb) + + maskb = binary.BigEndian.Uint64(mask[8:]) + maskAddrb = binary.BigEndian.Uint64(maskAddr[8:]) + calcAddrb = binary.BigEndian.Uint64(calcAddr[8:]) + ap.Lo = (maskAddrb & maskb) | (calcAddrb & ^maskb) + + return &ap } func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) { @@ -89,8 +99,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) } - //TODO: IPV6-WORK this does not verify that rawValue contains the same bits as cidr here - entry, err := newCalculatedRemotesListFromConfig(rawValue) + entry, err := newCalculatedRemotesListFromConfig(cidr, rawValue) if err != nil { return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err) } @@ -101,7 +110,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu return calculatedRemotes, nil } -func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) { +func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculatedRemote, error) { rawList, ok := raw.([]any) if !ok { return nil, fmt.Errorf("calculated_remotes entry has invalid type: %T", raw) @@ -109,7 +118,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) { var l []*calculatedRemote for _, e := range rawList { - c, err := newCalculatedRemotesEntryFromConfig(e) + c, err := newCalculatedRemotesEntryFromConfig(cidr, e) if err != nil { return nil, fmt.Errorf("calculated_remotes entry: %w", err) } @@ -119,7 +128,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) { return l, nil } -func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) { +func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) { rawMap, ok := raw.(map[any]any) if !ok { return nil, fmt.Errorf("invalid type: %T", raw) @@ -155,5 +164,5 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) { return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue) } - return newCalculatedRemote(maskCidr, port) + return newCalculatedRemote(cidr, maskCidr, port) } diff --git a/calculated_remote_test.go b/calculated_remote_test.go index 6ff1cb0..066213e 100644 --- a/calculated_remote_test.go +++ b/calculated_remote_test.go @@ -9,10 +9,9 @@ import ( ) func TestCalculatedRemoteApply(t *testing.T) { - ipNet, err := netip.ParsePrefix("192.168.1.0/24") - require.NoError(t, err) - - c, err := newCalculatedRemote(ipNet, 4242) + // Test v4 addresses + ipNet := netip.MustParsePrefix("192.168.1.0/24") + c, err := newCalculatedRemote(ipNet, ipNet, 4242) require.NoError(t, err) input, err := netip.ParseAddr("10.0.10.182") @@ -21,5 +20,62 @@ func TestCalculatedRemoteApply(t *testing.T) { expected, err := netip.ParseAddr("192.168.1.182") assert.NoError(t, err) - assert.Equal(t, NewIp4AndPortFromNetIP(expected, 4242), c.Apply(input)) + assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input)) + + // Test v6 addresses + ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff::0/64") + c, err = newCalculatedRemote(ipNet, ipNet, 4242) + require.NoError(t, err) + + input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") + assert.NoError(t, err) + + expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef") + assert.NoError(t, err) + + assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) + + // Test v6 addresses part 2 + ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff:ffff::0/80") + c, err = newCalculatedRemote(ipNet, ipNet, 4242) + require.NoError(t, err) + + input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") + assert.NoError(t, err) + + expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef") + assert.NoError(t, err) + + assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) + + // Test v6 addresses part 2 + ipNet = netip.MustParsePrefix("ffff:ffff:ffff::0/48") + c, err = newCalculatedRemote(ipNet, ipNet, 4242) + require.NoError(t, err) + + input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") + assert.NoError(t, err) + + expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef") + assert.NoError(t, err) + + assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) +} + +func Test_newCalculatedRemote(t *testing.T) { + c, err := newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1.0.0.0/32"), 4242) + require.EqualError(t, err, "invalid mask: 1.0.0.0/32 for cidr: 1::1/128") + require.Nil(t, c) + + c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1::1/128"), 4242) + require.EqualError(t, err, "invalid mask: 1::1/128 for cidr: 1.0.0.0/32") + require.Nil(t, c) + + c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1.0.0.0/32"), 4242) + require.NoError(t, err) + require.NotNil(t, c) + + c, err = newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1::1/128"), 4242) + require.NoError(t, err) + require.NotNil(t, c) } diff --git a/cert/Makefile b/cert/Makefile index 28170b6..311afc2 100644 --- a/cert/Makefile +++ b/cert/Makefile @@ -1,7 +1,7 @@ GO111MODULE = on export GO111MODULE -cert.pb.go: cert.proto .FORCE +cert_v1.pb.go: cert_v1.proto .FORCE go build google.golang.org/protobuf/cmd/protoc-gen-go PATH="$(CURDIR):$(PATH)" protoc --go_out=. --go_opt=paths=source_relative $< rm protoc-gen-go diff --git a/cert/README.md b/cert/README.md index ae19a28..1e27a6b 100644 --- a/cert/README.md +++ b/cert/README.md @@ -2,14 +2,25 @@ This is a library for interacting with `nebula` style certificates and authorities. -A `protobuf` definition of the certificate format is also included +There are now 2 versions of `nebula` certificates: -### Compiling the protobuf definition +## v1 -Make sure you have `protoc` installed. +This version is deprecated. + +A `protobuf` definition of the certificate format is included at `cert_v1.proto` + +To compile the definition you will need `protoc` installed. To compile for `go` with the same version of protobuf specified in go.mod: ```bash -make +make proto ``` + +## v2 + +This is the latest version which uses asn.1 DER encoding. It can support ipv4 and ipv6 and tolerate +future certificate changes better than v1. + +`cert_v2.asn1` defines the wire format and can be used to compile marshalers. \ No newline at end of file diff --git a/cert/asn1.go b/cert/asn1.go new file mode 100644 index 0000000..6bf6a8d --- /dev/null +++ b/cert/asn1.go @@ -0,0 +1,52 @@ +package cert + +import ( + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/cryptobyte/asn1" +) + +// readOptionalASN1Boolean reads an asn.1 boolean with a specific tag instead of a asn.1 tag wrapping a boolean with a value +// https://github.com/golang/go/issues/64811#issuecomment-1944446920 +func readOptionalASN1Boolean(b *cryptobyte.String, out *bool, tag asn1.Tag, defaultValue bool) bool { + var present bool + var child cryptobyte.String + if !b.ReadOptionalASN1(&child, &present, tag) { + return false + } + + if !present { + *out = defaultValue + return true + } + + // Ensure we have 1 byte + if len(child) == 1 { + *out = child[0] > 0 + return true + } + + return false +} + +// readOptionalASN1Byte reads an asn.1 uint8 with a specific tag instead of a asn.1 tag wrapping a uint8 with a value +// Similar issue as with readOptionalASN1Boolean +func readOptionalASN1Byte(b *cryptobyte.String, out *byte, tag asn1.Tag, defaultValue byte) bool { + var present bool + var child cryptobyte.String + if !b.ReadOptionalASN1(&child, &present, tag) { + return false + } + + if !present { + *out = defaultValue + return true + } + + // Ensure we have 1 byte + if len(child) == 1 { + *out = child[0] + return true + } + + return false +} diff --git a/cert/ca.go b/cert/ca.go deleted file mode 100644 index 0ffbd87..0000000 --- a/cert/ca.go +++ /dev/null @@ -1,140 +0,0 @@ -package cert - -import ( - "errors" - "fmt" - "strings" - "time" -) - -type NebulaCAPool struct { - CAs map[string]*NebulaCertificate - certBlocklist map[string]struct{} -} - -// NewCAPool creates a CAPool -func NewCAPool() *NebulaCAPool { - ca := NebulaCAPool{ - CAs: make(map[string]*NebulaCertificate), - certBlocklist: make(map[string]struct{}), - } - - return &ca -} - -// NewCAPoolFromBytes will create a new CA pool from the provided -// input bytes, which must be a PEM-encoded set of nebula certificates. -// If the pool contains any expired certificates, an ErrExpired will be -// returned along with the pool. The caller must handle any such errors. -func NewCAPoolFromBytes(caPEMs []byte) (*NebulaCAPool, error) { - pool := NewCAPool() - var err error - var expired bool - for { - caPEMs, err = pool.AddCACertificate(caPEMs) - if errors.Is(err, ErrExpired) { - expired = true - err = nil - } - if err != nil { - return nil, err - } - if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" { - break - } - } - - if expired { - return pool, ErrExpired - } - - return pool, nil -} - -// AddCACertificate verifies a Nebula CA certificate and adds it to the pool -// Only the first pem encoded object will be consumed, any remaining bytes are returned. -// Parsed certificates will be verified and must be a CA -func (ncp *NebulaCAPool) AddCACertificate(pemBytes []byte) ([]byte, error) { - c, pemBytes, err := UnmarshalNebulaCertificateFromPEM(pemBytes) - if err != nil { - return pemBytes, err - } - - if !c.Details.IsCA { - return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrNotCA) - } - - if !c.CheckSignature(c.Details.PublicKey) { - return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrNotSelfSigned) - } - - sum, err := c.Sha256Sum() - if err != nil { - return pemBytes, fmt.Errorf("could not calculate shasum for provided CA; error: %s; %s", err, c.Details.Name) - } - - ncp.CAs[sum] = c - if c.Expired(time.Now()) { - return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrExpired) - } - - return pemBytes, nil -} - -// BlocklistFingerprint adds a cert fingerprint to the blocklist -func (ncp *NebulaCAPool) BlocklistFingerprint(f string) { - ncp.certBlocklist[f] = struct{}{} -} - -// ResetCertBlocklist removes all previously blocklisted cert fingerprints -func (ncp *NebulaCAPool) ResetCertBlocklist() { - ncp.certBlocklist = make(map[string]struct{}) -} - -// NOTE: This uses an internal cache for Sha256Sum() that will not be invalidated -// automatically if you manually change any fields in the NebulaCertificate. -func (ncp *NebulaCAPool) IsBlocklisted(c *NebulaCertificate) bool { - return ncp.isBlocklistedWithCache(c, false) -} - -// IsBlocklisted returns true if the fingerprint fails to generate or has been explicitly blocklisted -func (ncp *NebulaCAPool) isBlocklistedWithCache(c *NebulaCertificate, useCache bool) bool { - h, err := c.sha256SumWithCache(useCache) - if err != nil { - return true - } - - if _, ok := ncp.certBlocklist[h]; ok { - return true - } - - return false -} - -// GetCAForCert attempts to return the signing certificate for the provided certificate. -// No signature validation is performed -func (ncp *NebulaCAPool) GetCAForCert(c *NebulaCertificate) (*NebulaCertificate, error) { - if c.Details.Issuer == "" { - return nil, fmt.Errorf("no issuer in certificate") - } - - signer, ok := ncp.CAs[c.Details.Issuer] - if ok { - return signer, nil - } - - return nil, fmt.Errorf("could not find ca for the certificate") -} - -// GetFingerprints returns an array of trusted CA fingerprints -func (ncp *NebulaCAPool) GetFingerprints() []string { - fp := make([]string, len(ncp.CAs)) - - i := 0 - for k := range ncp.CAs { - fp[i] = k - i++ - } - - return fp -} diff --git a/cert/ca_pool.go b/cert/ca_pool.go new file mode 100644 index 0000000..2bf480f --- /dev/null +++ b/cert/ca_pool.go @@ -0,0 +1,296 @@ +package cert + +import ( + "errors" + "fmt" + "net/netip" + "slices" + "strings" + "time" +) + +type CAPool struct { + CAs map[string]*CachedCertificate + certBlocklist map[string]struct{} +} + +// NewCAPool creates an empty CAPool +func NewCAPool() *CAPool { + ca := CAPool{ + CAs: make(map[string]*CachedCertificate), + certBlocklist: make(map[string]struct{}), + } + + return &ca +} + +// NewCAPoolFromPEM will create a new CA pool from the provided +// input bytes, which must be a PEM-encoded set of nebula certificates. +// If the pool contains any expired certificates, an ErrExpired will be +// returned along with the pool. The caller must handle any such errors. +func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) { + pool := NewCAPool() + var err error + var expired bool + for { + caPEMs, err = pool.AddCAFromPEM(caPEMs) + if errors.Is(err, ErrExpired) { + expired = true + err = nil + } + if err != nil { + return nil, err + } + if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" { + break + } + } + + if expired { + return pool, ErrExpired + } + + return pool, nil +} + +// AddCAFromPEM verifies a Nebula CA certificate and adds it to the pool. +// Only the first pem encoded object will be consumed, any remaining bytes are returned. +// Parsed certificates will be verified and must be a CA +func (ncp *CAPool) AddCAFromPEM(pemBytes []byte) ([]byte, error) { + c, pemBytes, err := UnmarshalCertificateFromPEM(pemBytes) + if err != nil { + return pemBytes, err + } + + err = ncp.AddCA(c) + if err != nil { + return pemBytes, err + } + + return pemBytes, nil +} + +// AddCA verifies a Nebula CA certificate and adds it to the pool. +func (ncp *CAPool) AddCA(c Certificate) error { + if !c.IsCA() { + return fmt.Errorf("%s: %w", c.Name(), ErrNotCA) + } + + if !c.CheckSignature(c.PublicKey()) { + return fmt.Errorf("%s: %w", c.Name(), ErrNotSelfSigned) + } + + sum, err := c.Fingerprint() + if err != nil { + return fmt.Errorf("could not calculate fingerprint for provided CA; error: %w; %s", err, c.Name()) + } + + cc := &CachedCertificate{ + Certificate: c, + Fingerprint: sum, + InvertedGroups: make(map[string]struct{}), + } + + for _, g := range c.Groups() { + cc.InvertedGroups[g] = struct{}{} + } + + ncp.CAs[sum] = cc + + if c.Expired(time.Now()) { + return fmt.Errorf("%s: %w", c.Name(), ErrExpired) + } + + return nil +} + +// BlocklistFingerprint adds a cert fingerprint to the blocklist +func (ncp *CAPool) BlocklistFingerprint(f string) { + ncp.certBlocklist[f] = struct{}{} +} + +// ResetCertBlocklist removes all previously blocklisted cert fingerprints +func (ncp *CAPool) ResetCertBlocklist() { + ncp.certBlocklist = make(map[string]struct{}) +} + +// IsBlocklisted tests the provided fingerprint against the pools blocklist. +// Returns true if the fingerprint is blocked. +func (ncp *CAPool) IsBlocklisted(fingerprint string) bool { + if _, ok := ncp.certBlocklist[fingerprint]; ok { + return true + } + + return false +} + +// VerifyCertificate verifies the certificate is valid and is signed by a trusted CA in the pool. +// If the certificate is valid then the returned CachedCertificate can be used in subsequent verification attempts +// to increase performance. +func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCertificate, error) { + if c == nil { + return nil, fmt.Errorf("no certificate") + } + fp, err := c.Fingerprint() + if err != nil { + return nil, fmt.Errorf("could not calculate fingerprint to verify: %w", err) + } + + signer, err := ncp.verify(c, now, fp, "") + if err != nil { + return nil, err + } + + cc := CachedCertificate{ + Certificate: c, + InvertedGroups: make(map[string]struct{}), + Fingerprint: fp, + signerFingerprint: signer.Fingerprint, + } + + for _, g := range c.Groups() { + cc.InvertedGroups[g] = struct{}{} + } + + return &cc, nil +} + +// VerifyCachedCertificate is the same as VerifyCertificate other than it operates on a pre-verified structure and +// is a cheaper operation to perform as a result. +func (ncp *CAPool) VerifyCachedCertificate(now time.Time, c *CachedCertificate) error { + _, err := ncp.verify(c.Certificate, now, c.Fingerprint, c.signerFingerprint) + return err +} + +func (ncp *CAPool) verify(c Certificate, now time.Time, certFp string, signerFp string) (*CachedCertificate, error) { + if ncp.IsBlocklisted(certFp) { + return nil, ErrBlockListed + } + + signer, err := ncp.GetCAForCert(c) + if err != nil { + return nil, err + } + + if signer.Certificate.Expired(now) { + return nil, ErrRootExpired + } + + if c.Expired(now) { + return nil, ErrExpired + } + + // If we are checking a cached certificate then we can bail early here + // Either the root is no longer trusted or everything is fine + if len(signerFp) > 0 { + if signerFp != signer.Fingerprint { + return nil, ErrFingerprintMismatch + } + return signer, nil + } + if !c.CheckSignature(signer.Certificate.PublicKey()) { + return nil, ErrSignatureMismatch + } + + err = CheckCAConstraints(signer.Certificate, c) + if err != nil { + return nil, err + } + + return signer, nil +} + +// GetCAForCert attempts to return the signing certificate for the provided certificate. +// No signature validation is performed +func (ncp *CAPool) GetCAForCert(c Certificate) (*CachedCertificate, error) { + issuer := c.Issuer() + if issuer == "" { + return nil, fmt.Errorf("no issuer in certificate") + } + + signer, ok := ncp.CAs[issuer] + if ok { + return signer, nil + } + + return nil, ErrCaNotFound +} + +// GetFingerprints returns an array of trusted CA fingerprints +func (ncp *CAPool) GetFingerprints() []string { + fp := make([]string, len(ncp.CAs)) + + i := 0 + for k := range ncp.CAs { + fp[i] = k + i++ + } + + return fp +} + +// CheckCAConstraints returns an error if the sub certificate violates constraints present in the signer certificate. +func CheckCAConstraints(signer Certificate, sub Certificate) error { + return checkCAConstraints(signer, sub.NotBefore(), sub.NotAfter(), sub.Groups(), sub.Networks(), sub.UnsafeNetworks()) +} + +// checkCAConstraints is a very generic function allowing both Certificates and TBSCertificates to be tested. +func checkCAConstraints(signer Certificate, notBefore, notAfter time.Time, groups []string, networks, unsafeNetworks []netip.Prefix) error { + // Make sure this cert isn't valid after the root + if notAfter.After(signer.NotAfter()) { + return fmt.Errorf("certificate expires after signing certificate") + } + + // Make sure this cert wasn't valid before the root + if notBefore.Before(signer.NotBefore()) { + return fmt.Errorf("certificate is valid before the signing certificate") + } + + // If the signer has a limited set of groups make sure the cert only contains a subset + signerGroups := signer.Groups() + if len(signerGroups) > 0 { + for _, g := range groups { + if !slices.Contains(signerGroups, g) { + return fmt.Errorf("certificate contained a group not present on the signing ca: %s", g) + } + } + } + + // If the signer has a limited set of ip ranges to issue from make sure the cert only contains a subset + signingNetworks := signer.Networks() + if len(signingNetworks) > 0 { + for _, certNetwork := range networks { + found := false + for _, signingNetwork := range signingNetworks { + if signingNetwork.Contains(certNetwork.Addr()) && signingNetwork.Bits() <= certNetwork.Bits() { + found = true + break + } + } + + if !found { + return fmt.Errorf("certificate contained a network assignment outside the limitations of the signing ca: %s", certNetwork.String()) + } + } + } + + // If the signer has a limited set of subnet ranges to issue from make sure the cert only contains a subset + signingUnsafeNetworks := signer.UnsafeNetworks() + if len(signingUnsafeNetworks) > 0 { + for _, certUnsafeNetwork := range unsafeNetworks { + found := false + for _, caNetwork := range signingUnsafeNetworks { + if caNetwork.Contains(certUnsafeNetwork.Addr()) && caNetwork.Bits() <= certUnsafeNetwork.Bits() { + found = true + break + } + } + + if !found { + return fmt.Errorf("certificate contained an unsafe network assignment outside the limitations of the signing ca: %s", certUnsafeNetwork.String()) + } + } + } + + return nil +} diff --git a/cert/ca_pool_test.go b/cert/ca_pool_test.go new file mode 100644 index 0000000..f03b2ba --- /dev/null +++ b/cert/ca_pool_test.go @@ -0,0 +1,559 @@ +package cert + +import ( + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewCAPoolFromBytes(t *testing.T) { + noNewLines := ` +# Current provisional, Remove once everything moves over to the real root. +-----BEGIN NEBULA CERTIFICATE----- +Cj4KDm5lYnVsYSByb290IGNhKM0cMM24zPCvBzogV24YEw5YiqeI/oYo8XXFsoo+ +PBmiOafNJhLacf9rsspAARJAz9OAnh8TKAUKix1kKVMyQU4iM3LsFfZRf6ODWXIf +2qWMpB6fpd3PSoVYziPoOt2bIHIFLlgRLPJz3I3xBEdBCQ== +-----END NEBULA CERTIFICATE----- +# root-ca01 +-----BEGIN NEBULA CERTIFICATE----- +CkEKEW5lYnVsYSByb290IGNhIDAxKM0cMM24zPCvBzogPzbWTxt8ZgXPQEwup7Br +BrtIt1O0q5AuTRT3+t2x1VJAARJAZ+2ib23qBXjdy49oU1YysrwuKkWWKrtJ7Jye +rFBQpDXikOukhQD/mfkloFwJ+Yjsfru7IpTN4ZfjXL+kN/2sCA== +-----END NEBULA CERTIFICATE----- +` + + withNewLines := ` +# Current provisional, Remove once everything moves over to the real root. + +-----BEGIN NEBULA CERTIFICATE----- +Cj4KDm5lYnVsYSByb290IGNhKM0cMM24zPCvBzogV24YEw5YiqeI/oYo8XXFsoo+ +PBmiOafNJhLacf9rsspAARJAz9OAnh8TKAUKix1kKVMyQU4iM3LsFfZRf6ODWXIf +2qWMpB6fpd3PSoVYziPoOt2bIHIFLlgRLPJz3I3xBEdBCQ== +-----END NEBULA CERTIFICATE----- + +# root-ca01 + + +-----BEGIN NEBULA CERTIFICATE----- +CkEKEW5lYnVsYSByb290IGNhIDAxKM0cMM24zPCvBzogPzbWTxt8ZgXPQEwup7Br +BrtIt1O0q5AuTRT3+t2x1VJAARJAZ+2ib23qBXjdy49oU1YysrwuKkWWKrtJ7Jye +rFBQpDXikOukhQD/mfkloFwJ+Yjsfru7IpTN4ZfjXL+kN/2sCA== +-----END NEBULA CERTIFICATE----- + +` + + expired := ` +# expired certificate +-----BEGIN NEBULA CERTIFICATE----- +CjMKB2V4cGlyZWQozRwwzRw6ICJSG94CqX8wn5I65Pwn25V6HftVfWeIySVtp2DA +7TY/QAESQMaAk5iJT5EnQwK524ZaaHGEJLUqqbh5yyOHhboIGiVTWkFeH3HccTW8 +Tq5a8AyWDQdfXbtEZ1FwabeHfH5Asw0= +-----END NEBULA CERTIFICATE----- +` + + p256 := ` +# p256 certificate +-----BEGIN NEBULA CERTIFICATE----- +CmQKEG5lYnVsYSBQMjU2IHRlc3QozRwwzbjM8K8HOkEEdrmmg40zQp44AkMq6DZp +k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe ++0ABoAYBEkcwRQIgVoTg38L7uWku9xQgsr06kxZ/viQLOO/w1Qj1vFUEnhcCIQCq +75SjTiV92kv/1GcbT3wWpAZQQDBiUHVMVmh1822szA== +-----END NEBULA CERTIFICATE----- +` + + rootCA := certificateV1{ + details: detailsV1{ + name: "nebula root ca", + }, + } + + rootCA01 := certificateV1{ + details: detailsV1{ + name: "nebula root ca 01", + }, + } + + rootCAP256 := certificateV1{ + details: detailsV1{ + name: "nebula P256 test", + }, + } + + p, err := NewCAPoolFromPEM([]byte(noNewLines)) + assert.Nil(t, err) + assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) + assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) + + pp, err := NewCAPoolFromPEM([]byte(withNewLines)) + assert.Nil(t, err) + assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) + assert.Equal(t, pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) + + // expired cert, no valid certs + ppp, err := NewCAPoolFromPEM([]byte(expired)) + assert.Equal(t, ErrExpired, err) + assert.Equal(t, ppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name(), "expired") + + // expired cert, with valid certs + pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...)) + assert.Equal(t, ErrExpired, err) + assert.Equal(t, pppp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) + assert.Equal(t, pppp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) + assert.Equal(t, pppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name(), "expired") + assert.Equal(t, len(pppp.CAs), 3) + + ppppp, err := NewCAPoolFromPEM([]byte(p256)) + assert.Nil(t, err) + assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name) + assert.Equal(t, len(ppppp.CAs), 1) +} + +func TestCertificateV1_Verify(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) + c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + + caPool := NewCAPool() + assert.NoError(t, caPool.AddCA(ca)) + + f, err := c.Fingerprint() + assert.Nil(t, err) + caPool.BlocklistFingerprint(f) + + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.EqualError(t, err, "certificate is in the block list") + + caPool.ResetCertBlocklist() + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) + assert.EqualError(t, err, "root certificate is expired") + + assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil) + }) + + // Test group assertion + ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) + caPem, err := ca.MarshalPEM() + assert.Nil(t, err) + + caPool = NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) + + assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) + }) + + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) +} + +func TestCertificateV1_VerifyP256(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) + c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + + caPool := NewCAPool() + assert.NoError(t, caPool.AddCA(ca)) + + f, err := c.Fingerprint() + assert.Nil(t, err) + caPool.BlocklistFingerprint(f) + + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.EqualError(t, err, "certificate is in the block list") + + caPool.ResetCertBlocklist() + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) + assert.EqualError(t, err, "root certificate is expired") + + assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { + NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + }) + + // Test group assertion + ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) + caPem, err := ca.MarshalPEM() + assert.Nil(t, err) + + caPool = NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) + + assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { + NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) + }) + + c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) +} + +func TestCertificateV1_Verify_IPs(t *testing.T) { + caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") + caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) + + caPem, err := ca.MarshalPEM() + assert.Nil(t, err) + + caPool := NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) + + // ip is outside the network + cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") + cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip is outside the network reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip is within the network but mask is outside + cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip is within the network but mask is outside reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip and mask are within the network + cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") + c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches reversed + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches reversed with just 1 + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) +} + +func TestCertificateV1_Verify_Subnets(t *testing.T) { + caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") + caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) + + caPem, err := ca.MarshalPEM() + assert.Nil(t, err) + + caPool := NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) + + // ip is outside the network + cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") + cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip is outside the network reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip is within the network but mask is outside + cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip is within the network but mask is outside reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip and mask are within the network + cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") + c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches reversed + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches reversed with just 1 + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) +} + +func TestCertificateV2_Verify(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) + c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + + caPool := NewCAPool() + assert.NoError(t, caPool.AddCA(ca)) + + f, err := c.Fingerprint() + assert.Nil(t, err) + caPool.BlocklistFingerprint(f) + + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.EqualError(t, err, "certificate is in the block list") + + caPool.ResetCertBlocklist() + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) + assert.EqualError(t, err, "root certificate is expired") + + assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil) + }) + + // Test group assertion + ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) + caPem, err := ca.MarshalPEM() + assert.Nil(t, err) + + caPool = NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) + + assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) + }) + + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) +} + +func TestCertificateV2_VerifyP256(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) + c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + + caPool := NewCAPool() + assert.NoError(t, caPool.AddCA(ca)) + + f, err := c.Fingerprint() + assert.Nil(t, err) + caPool.BlocklistFingerprint(f) + + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.EqualError(t, err, "certificate is in the block list") + + caPool.ResetCertBlocklist() + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) + assert.EqualError(t, err, "root certificate is expired") + + assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { + NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + }) + + // Test group assertion + ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) + caPem, err := ca.MarshalPEM() + assert.Nil(t, err) + + caPool = NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) + + assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { + NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) + }) + + c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) +} + +func TestCertificateV2_Verify_IPs(t *testing.T) { + caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") + caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) + + caPem, err := ca.MarshalPEM() + assert.Nil(t, err) + + caPool := NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) + + // ip is outside the network + cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") + cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip is outside the network reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip is within the network but mask is outside + cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip is within the network but mask is outside reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip and mask are within the network + cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") + c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches reversed + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches reversed with just 1 + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) +} + +func TestCertificateV2_Verify_Subnets(t *testing.T) { + caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") + caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) + + caPem, err := ca.MarshalPEM() + assert.Nil(t, err) + + caPool := NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) + + // ip is outside the network + cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") + cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip is outside the network reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip is within the network but mask is outside + cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip is within the network but mask is outside reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip and mask are within the network + cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") + c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches reversed + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches reversed with just 1 + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) +} diff --git a/cert/cert.go b/cert/cert.go index a0164f7..4246571 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -1,1029 +1,165 @@ package cert import ( - "bytes" - "crypto/ecdh" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/elliptic" - "crypto/rand" - "crypto/sha256" - "encoding/binary" - "encoding/hex" - "encoding/json" - "encoding/pem" - "errors" "fmt" - "math" - "math/big" - "net" - "sync/atomic" + "net/netip" "time" - - "golang.org/x/crypto/curve25519" - "google.golang.org/protobuf/proto" ) -const publicKeyLen = 32 +type Version uint8 const ( - CertBanner = "NEBULA CERTIFICATE" - X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY" - X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY" - EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY" - Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY" - Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY" - - P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY" - P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY" - EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY" - ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY" + VersionPre1 Version = 0 + Version1 Version = 1 + Version2 Version = 2 ) -type NebulaCertificate struct { - Details NebulaCertificateDetails - Signature []byte +type Certificate interface { + // Version defines the underlying certificate structure and wire protocol version + // Version1 certificates are ipv4 only and uses protobuf serialization + // Version2 certificates are ipv4 or ipv6 and uses asn.1 serialization + Version() Version - // the cached hex string of the calculated sha256sum - // for VerifyWithCache - sha256sum atomic.Pointer[string] + // Name is the human-readable name that identifies this certificate. + Name() string - // the cached public key bytes if they were verified as the signer - // for VerifyWithCache - signatureVerified atomic.Pointer[[]byte] + // Networks is a list of ip addresses and network sizes assigned to this certificate. + // If IsCA is true then certificates signed by this CA can only have ip addresses and + // networks that are contained by an entry in this list. + Networks() []netip.Prefix + + // UnsafeNetworks is a list of networks that this host can act as an unsafe router for. + // If IsCA is true then certificates signed by this CA can only have networks that are + // contained by an entry in this list. + UnsafeNetworks() []netip.Prefix + + // Groups is a list of identities that can be used to write more general firewall rule + // definitions. + // If IsCA is true then certificates signed by this CA can only use groups that are + // in this list. + Groups() []string + + // IsCA signifies if this is a certificate authority (true) or a host certificate (false). + // It is invalid to use a CA certificate as a host certificate. + IsCA() bool + + // NotBefore is the time at which this certificate becomes valid. + // If IsCA is true then certificate signed by this CA can not have a time before this. + NotBefore() time.Time + + // NotAfter is the time at which this certificate becomes invalid. + // If IsCA is true then certificate signed by this CA can not have a time after this. + NotAfter() time.Time + + // Issuer is the fingerprint of the CA that signed this certificate. + // If IsCA is true then this will be empty. + Issuer() string + + // PublicKey is the raw bytes to be used in asymmetric cryptographic operations. + PublicKey() []byte + + // Curve identifies which curve was used for the PublicKey and Signature. + Curve() Curve + + // Signature is the cryptographic seal for all the details of this certificate. + // CheckSignature can be used to verify that the details of this certificate are valid. + Signature() []byte + + // CheckSignature will check that the certificate Signature() matches the + // computed signature. A true result means this certificate has not been tampered with. + CheckSignature(signingPublicKey []byte) bool + + // Fingerprint returns the hex encoded sha256 sum of the certificate. + // This acts as a unique fingerprint and can be used to blocklist certificates. + Fingerprint() (string, error) + + // Expired tests if the certificate is valid for the provided time. + Expired(t time.Time) bool + + // VerifyPrivateKey returns an error if the private key is not a pair with the certificates public key. + VerifyPrivateKey(curve Curve, privateKey []byte) error + + // Marshal will return the byte representation of this certificate + // This is primarily the format transmitted on the wire. + Marshal() ([]byte, error) + + // MarshalForHandshakes prepares the bytes needed to use directly in a handshake + MarshalForHandshakes() ([]byte, error) + + // MarshalPEM will return a PEM encoded representation of this certificate + // This is primarily the format stored on disk + MarshalPEM() ([]byte, error) + + // MarshalJSON will return the json representation of this certificate + MarshalJSON() ([]byte, error) + + // String will return a human-readable representation of this certificate + String() string + + // Copy creates a copy of the certificate + Copy() Certificate } -type NebulaCertificateDetails struct { - Name string - Ips []*net.IPNet - Subnets []*net.IPNet - Groups []string - NotBefore time.Time - NotAfter time.Time - PublicKey []byte - IsCA bool - Issuer string - - // Map of groups for faster lookup - InvertedGroups map[string]struct{} - - Curve Curve +// CachedCertificate represents a verified certificate with some cached fields to improve +// performance. +type CachedCertificate struct { + Certificate Certificate + InvertedGroups map[string]struct{} + Fingerprint string + signerFingerprint string } -type NebulaEncryptedData struct { - EncryptionMetadata NebulaEncryptionMetadata - Ciphertext []byte +func (cc *CachedCertificate) String() string { + return cc.Certificate.String() } -type NebulaEncryptionMetadata struct { - EncryptionAlgorithm string - Argon2Parameters Argon2Parameters -} - -type m map[string]interface{} - -// Returned if we try to unmarshal an encrypted private key without a passphrase -var ErrPrivateKeyEncrypted = errors.New("private key must be decrypted") - -// UnmarshalNebulaCertificate will unmarshal a protobuf byte representation of a nebula cert -func UnmarshalNebulaCertificate(b []byte) (*NebulaCertificate, error) { - if len(b) == 0 { - return nil, fmt.Errorf("nil byte array") +// RecombineAndValidate will attempt to unmarshal a certificate received in a handshake. +// Handshakes save space by placing the peers public key in a different part of the packet, we have to +// reassemble the actual certificate structure with that in mind. +func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, curve Curve, caPool *CAPool) (*CachedCertificate, error) { + if publicKey == nil { + return nil, ErrNoPeerStaticKey } - var rc RawNebulaCertificate - err := proto.Unmarshal(b, &rc) + + if rawCertBytes == nil { + return nil, ErrNoPayload + } + + c, err := unmarshalCertificateFromHandshake(v, rawCertBytes, publicKey, curve) + if err != nil { + return nil, fmt.Errorf("error unmarshaling cert: %w", err) + } + + cc, err := caPool.VerifyCertificate(time.Now(), c) + if err != nil { + return nil, fmt.Errorf("certificate validation failed: %w", err) + } + + return cc, nil +} + +func unmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte, curve Curve) (Certificate, error) { + var c Certificate + var err error + + switch v { + // Implementations must ensure the result is a valid cert! + case VersionPre1, Version1: + c, err = unmarshalCertificateV1(b, publicKey) + case Version2: + c, err = unmarshalCertificateV2(b, publicKey, curve) + default: + //TODO: CERT-V2 make a static var + return nil, fmt.Errorf("unknown certificate version %d", v) + } + if err != nil { return nil, err } - if rc.Details == nil { - return nil, fmt.Errorf("encoded Details was nil") + if c.Curve() != curve { + return nil, fmt.Errorf("certificate curve %s does not match expected %s", c.Curve().String(), curve.String()) } - if len(rc.Details.Ips)%2 != 0 { - return nil, fmt.Errorf("encoded IPs should be in pairs, an odd number was found") - } - - if len(rc.Details.Subnets)%2 != 0 { - return nil, fmt.Errorf("encoded Subnets should be in pairs, an odd number was found") - } - - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: rc.Details.Name, - Groups: make([]string, len(rc.Details.Groups)), - Ips: make([]*net.IPNet, len(rc.Details.Ips)/2), - Subnets: make([]*net.IPNet, len(rc.Details.Subnets)/2), - NotBefore: time.Unix(rc.Details.NotBefore, 0), - NotAfter: time.Unix(rc.Details.NotAfter, 0), - PublicKey: make([]byte, len(rc.Details.PublicKey)), - IsCA: rc.Details.IsCA, - InvertedGroups: make(map[string]struct{}), - Curve: rc.Details.Curve, - }, - Signature: make([]byte, len(rc.Signature)), - } - - copy(nc.Signature, rc.Signature) - copy(nc.Details.Groups, rc.Details.Groups) - nc.Details.Issuer = hex.EncodeToString(rc.Details.Issuer) - - if len(rc.Details.PublicKey) < publicKeyLen { - return nil, fmt.Errorf("Public key was fewer than 32 bytes; %v", len(rc.Details.PublicKey)) - } - copy(nc.Details.PublicKey, rc.Details.PublicKey) - - for i, rawIp := range rc.Details.Ips { - if i%2 == 0 { - nc.Details.Ips[i/2] = &net.IPNet{IP: int2ip(rawIp)} - } else { - nc.Details.Ips[i/2].Mask = net.IPMask(int2ip(rawIp)) - } - } - - for i, rawIp := range rc.Details.Subnets { - if i%2 == 0 { - nc.Details.Subnets[i/2] = &net.IPNet{IP: int2ip(rawIp)} - } else { - nc.Details.Subnets[i/2].Mask = net.IPMask(int2ip(rawIp)) - } - } - - for _, g := range rc.Details.Groups { - nc.Details.InvertedGroups[g] = struct{}{} - } - - return &nc, nil -} - -// UnmarshalNebulaCertificateFromPEM will unmarshal the first pem block in a byte array, returning any non consumed data -// or an error on failure -func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) { - p, r := pem.Decode(b) - if p == nil { - return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") - } - if p.Type != CertBanner { - return nil, r, fmt.Errorf("bytes did not contain a proper nebula certificate banner") - } - nc, err := UnmarshalNebulaCertificate(p.Bytes) - return nc, r, err -} - -func MarshalPrivateKey(curve Curve, b []byte) []byte { - switch curve { - case Curve_CURVE25519: - return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b}) - case Curve_P256: - return pem.EncodeToMemory(&pem.Block{Type: P256PrivateKeyBanner, Bytes: b}) - default: - return nil - } -} - -func MarshalSigningPrivateKey(curve Curve, b []byte) []byte { - switch curve { - case Curve_CURVE25519: - return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: b}) - case Curve_P256: - return pem.EncodeToMemory(&pem.Block{Type: ECDSAP256PrivateKeyBanner, Bytes: b}) - default: - return nil - } -} - -// MarshalX25519PrivateKey is a simple helper to PEM encode an X25519 private key -func MarshalX25519PrivateKey(b []byte) []byte { - return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b}) -} - -// MarshalEd25519PrivateKey is a simple helper to PEM encode an Ed25519 private key -func MarshalEd25519PrivateKey(key ed25519.PrivateKey) []byte { - return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: key}) -} - -func UnmarshalPrivateKey(b []byte) ([]byte, []byte, Curve, error) { - k, r := pem.Decode(b) - if k == nil { - return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block") - } - var expectedLen int - var curve Curve - switch k.Type { - case X25519PrivateKeyBanner: - expectedLen = 32 - curve = Curve_CURVE25519 - case P256PrivateKeyBanner: - expectedLen = 32 - curve = Curve_P256 - default: - return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula private key banner") - } - if len(k.Bytes) != expectedLen { - return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s private key", expectedLen, curve) - } - return k.Bytes, r, curve, nil -} - -func UnmarshalSigningPrivateKey(b []byte) ([]byte, []byte, Curve, error) { - k, r := pem.Decode(b) - if k == nil { - return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block") - } - var curve Curve - switch k.Type { - case EncryptedEd25519PrivateKeyBanner: - return nil, nil, Curve_CURVE25519, ErrPrivateKeyEncrypted - case EncryptedECDSAP256PrivateKeyBanner: - return nil, nil, Curve_P256, ErrPrivateKeyEncrypted - case Ed25519PrivateKeyBanner: - curve = Curve_CURVE25519 - if len(k.Bytes) != ed25519.PrivateKeySize { - return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid Ed25519 private key", ed25519.PrivateKeySize) - } - case ECDSAP256PrivateKeyBanner: - curve = Curve_P256 - if len(k.Bytes) != 32 { - return nil, r, 0, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key") - } - default: - return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula Ed25519/ECDSA private key banner") - } - return k.Bytes, r, curve, nil -} - -// EncryptAndMarshalSigningPrivateKey is a simple helper to encrypt and PEM encode a private key -func EncryptAndMarshalSigningPrivateKey(curve Curve, b []byte, passphrase []byte, kdfParams *Argon2Parameters) ([]byte, error) { - ciphertext, err := aes256Encrypt(passphrase, kdfParams, b) - if err != nil { - return nil, err - } - - b, err = proto.Marshal(&RawNebulaEncryptedData{ - EncryptionMetadata: &RawNebulaEncryptionMetadata{ - EncryptionAlgorithm: "AES-256-GCM", - Argon2Parameters: &RawNebulaArgon2Parameters{ - Version: kdfParams.version, - Memory: kdfParams.Memory, - Parallelism: uint32(kdfParams.Parallelism), - Iterations: kdfParams.Iterations, - Salt: kdfParams.salt, - }, - }, - Ciphertext: ciphertext, - }) - if err != nil { - return nil, err - } - - switch curve { - case Curve_CURVE25519: - return pem.EncodeToMemory(&pem.Block{Type: EncryptedEd25519PrivateKeyBanner, Bytes: b}), nil - case Curve_P256: - return pem.EncodeToMemory(&pem.Block{Type: EncryptedECDSAP256PrivateKeyBanner, Bytes: b}), nil - default: - return nil, fmt.Errorf("invalid curve: %v", curve) - } -} - -// UnmarshalX25519PrivateKey will try to pem decode an X25519 private key, returning any other bytes b -// or an error on failure -func UnmarshalX25519PrivateKey(b []byte) ([]byte, []byte, error) { - k, r := pem.Decode(b) - if k == nil { - return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") - } - if k.Type != X25519PrivateKeyBanner { - return nil, r, fmt.Errorf("bytes did not contain a proper nebula X25519 private key banner") - } - if len(k.Bytes) != publicKeyLen { - return nil, r, fmt.Errorf("key was not 32 bytes, is invalid X25519 private key") - } - - return k.Bytes, r, nil -} - -// UnmarshalEd25519PrivateKey will try to pem decode an Ed25519 private key, returning any other bytes b -// or an error on failure -func UnmarshalEd25519PrivateKey(b []byte) (ed25519.PrivateKey, []byte, error) { - k, r := pem.Decode(b) - if k == nil { - return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") - } - - if k.Type == EncryptedEd25519PrivateKeyBanner { - return nil, r, ErrPrivateKeyEncrypted - } else if k.Type != Ed25519PrivateKeyBanner { - return nil, r, fmt.Errorf("bytes did not contain a proper nebula Ed25519 private key banner") - } - - if len(k.Bytes) != ed25519.PrivateKeySize { - return nil, r, fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") - } - - return k.Bytes, r, nil -} - -// UnmarshalNebulaEncryptedData will unmarshal a protobuf byte representation of a nebula cert into its -// protobuf-generated struct. -func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) { - if len(b) == 0 { - return nil, fmt.Errorf("nil byte array") - } - var rned RawNebulaEncryptedData - err := proto.Unmarshal(b, &rned) - if err != nil { - return nil, err - } - - if rned.EncryptionMetadata == nil { - return nil, fmt.Errorf("encoded EncryptionMetadata was nil") - } - - if rned.EncryptionMetadata.Argon2Parameters == nil { - return nil, fmt.Errorf("encoded Argon2Parameters was nil") - } - - params, err := unmarshalArgon2Parameters(rned.EncryptionMetadata.Argon2Parameters) - if err != nil { - return nil, err - } - - ned := NebulaEncryptedData{ - EncryptionMetadata: NebulaEncryptionMetadata{ - EncryptionAlgorithm: rned.EncryptionMetadata.EncryptionAlgorithm, - Argon2Parameters: *params, - }, - Ciphertext: rned.Ciphertext, - } - - return &ned, nil -} - -func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) { - if params.Version < math.MinInt32 || params.Version > math.MaxInt32 { - return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32) - } - if params.Memory <= 0 || params.Memory > math.MaxUint32 { - return nil, fmt.Errorf("Argon2Parameters Memory must be be greater than 0 and no more than %d KiB", uint32(math.MaxUint32)) - } - if params.Parallelism <= 0 || params.Parallelism > math.MaxUint8 { - return nil, fmt.Errorf("Argon2Parameters Parallelism must be be greater than 0 and no more than %d", math.MaxUint8) - } - if params.Iterations <= 0 || params.Iterations > math.MaxUint32 { - return nil, fmt.Errorf("-argon-iterations must be be greater than 0 and no more than %d", uint32(math.MaxUint32)) - } - - return &Argon2Parameters{ - version: rune(params.Version), - Memory: uint32(params.Memory), - Parallelism: uint8(params.Parallelism), - Iterations: uint32(params.Iterations), - salt: params.Salt, - }, nil - -} - -// DecryptAndUnmarshalSigningPrivateKey will try to pem decode and decrypt an Ed25519/ECDSA private key with -// the given passphrase, returning any other bytes b or an error on failure -func DecryptAndUnmarshalSigningPrivateKey(passphrase, b []byte) (Curve, []byte, []byte, error) { - var curve Curve - - k, r := pem.Decode(b) - if k == nil { - return curve, nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") - } - - switch k.Type { - case EncryptedEd25519PrivateKeyBanner: - curve = Curve_CURVE25519 - case EncryptedECDSAP256PrivateKeyBanner: - curve = Curve_P256 - default: - return curve, nil, r, fmt.Errorf("bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") - } - - ned, err := UnmarshalNebulaEncryptedData(k.Bytes) - if err != nil { - return curve, nil, r, err - } - - var bytes []byte - switch ned.EncryptionMetadata.EncryptionAlgorithm { - case "AES-256-GCM": - bytes, err = aes256Decrypt(passphrase, &ned.EncryptionMetadata.Argon2Parameters, ned.Ciphertext) - if err != nil { - return curve, nil, r, err - } - default: - return curve, nil, r, fmt.Errorf("unsupported encryption algorithm: %s", ned.EncryptionMetadata.EncryptionAlgorithm) - } - - switch curve { - case Curve_CURVE25519: - if len(bytes) != ed25519.PrivateKeySize { - return curve, nil, r, fmt.Errorf("key was not %d bytes, is invalid ed25519 private key", ed25519.PrivateKeySize) - } - case Curve_P256: - if len(bytes) != 32 { - return curve, nil, r, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key") - } - } - - return curve, bytes, r, nil -} - -func MarshalPublicKey(curve Curve, b []byte) []byte { - switch curve { - case Curve_CURVE25519: - return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b}) - case Curve_P256: - return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b}) - default: - return nil - } -} - -// MarshalX25519PublicKey is a simple helper to PEM encode an X25519 public key -func MarshalX25519PublicKey(b []byte) []byte { - return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b}) -} - -// MarshalEd25519PublicKey is a simple helper to PEM encode an Ed25519 public key -func MarshalEd25519PublicKey(key ed25519.PublicKey) []byte { - return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: key}) -} - -func UnmarshalPublicKey(b []byte) ([]byte, []byte, Curve, error) { - k, r := pem.Decode(b) - if k == nil { - return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block") - } - var expectedLen int - var curve Curve - switch k.Type { - case X25519PublicKeyBanner: - expectedLen = 32 - curve = Curve_CURVE25519 - case P256PublicKeyBanner: - // Uncompressed - expectedLen = 65 - curve = Curve_P256 - default: - return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula public key banner") - } - if len(k.Bytes) != expectedLen { - return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s public key", expectedLen, curve) - } - return k.Bytes, r, curve, nil -} - -// UnmarshalX25519PublicKey will try to pem decode an X25519 public key, returning any other bytes b -// or an error on failure -func UnmarshalX25519PublicKey(b []byte) ([]byte, []byte, error) { - k, r := pem.Decode(b) - if k == nil { - return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") - } - if k.Type != X25519PublicKeyBanner { - return nil, r, fmt.Errorf("bytes did not contain a proper nebula X25519 public key banner") - } - if len(k.Bytes) != publicKeyLen { - return nil, r, fmt.Errorf("key was not 32 bytes, is invalid X25519 public key") - } - - return k.Bytes, r, nil -} - -// UnmarshalEd25519PublicKey will try to pem decode an Ed25519 public key, returning any other bytes b -// or an error on failure -func UnmarshalEd25519PublicKey(b []byte) (ed25519.PublicKey, []byte, error) { - k, r := pem.Decode(b) - if k == nil { - return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") - } - if k.Type != Ed25519PublicKeyBanner { - return nil, r, fmt.Errorf("bytes did not contain a proper nebula Ed25519 public key banner") - } - if len(k.Bytes) != ed25519.PublicKeySize { - return nil, r, fmt.Errorf("key was not 32 bytes, is invalid ed25519 public key") - } - - return k.Bytes, r, nil -} - -// Sign signs a nebula cert with the provided private key -func (nc *NebulaCertificate) Sign(curve Curve, key []byte) error { - if curve != nc.Details.Curve { - return fmt.Errorf("curve in cert and private key supplied don't match") - } - - b, err := proto.Marshal(nc.getRawDetails()) - if err != nil { - return err - } - - var sig []byte - - switch curve { - case Curve_CURVE25519: - signer := ed25519.PrivateKey(key) - sig = ed25519.Sign(signer, b) - case Curve_P256: - signer := &ecdsa.PrivateKey{ - PublicKey: ecdsa.PublicKey{ - Curve: elliptic.P256(), - }, - // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95 - D: new(big.Int).SetBytes(key), - } - // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119 - signer.X, signer.Y = signer.Curve.ScalarBaseMult(key) - - // We need to hash first for ECDSA - // - https://pkg.go.dev/crypto/ecdsa#SignASN1 - hashed := sha256.Sum256(b) - sig, err = ecdsa.SignASN1(rand.Reader, signer, hashed[:]) - if err != nil { - return err - } - default: - return fmt.Errorf("invalid curve: %s", nc.Details.Curve) - } - - nc.Signature = sig - return nil -} - -// CheckSignature verifies the signature against the provided public key -func (nc *NebulaCertificate) CheckSignature(key []byte) bool { - b, err := proto.Marshal(nc.getRawDetails()) - if err != nil { - return false - } - switch nc.Details.Curve { - case Curve_CURVE25519: - return ed25519.Verify(ed25519.PublicKey(key), b, nc.Signature) - case Curve_P256: - x, y := elliptic.Unmarshal(elliptic.P256(), key) - pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} - hashed := sha256.Sum256(b) - return ecdsa.VerifyASN1(pubKey, hashed[:], nc.Signature) - default: - return false - } -} - -// NOTE: This uses an internal cache that will not be invalidated automatically -// if you manually change any fields in the NebulaCertificate. -func (nc *NebulaCertificate) checkSignatureWithCache(key []byte, useCache bool) bool { - if !useCache { - return nc.CheckSignature(key) - } - - if v := nc.signatureVerified.Load(); v != nil { - return bytes.Equal(*v, key) - } - - verified := nc.CheckSignature(key) - if verified { - keyCopy := make([]byte, len(key)) - copy(keyCopy, key) - nc.signatureVerified.Store(&keyCopy) - } - - return verified -} - -// Expired will return true if the nebula cert is too young or too old compared to the provided time, otherwise false -func (nc *NebulaCertificate) Expired(t time.Time) bool { - return nc.Details.NotBefore.After(t) || nc.Details.NotAfter.Before(t) -} - -// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc) -func (nc *NebulaCertificate) Verify(t time.Time, ncp *NebulaCAPool) (bool, error) { - return nc.verify(t, ncp, false) -} - -// VerifyWithCache will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc) -// -// NOTE: This uses an internal cache that will not be invalidated automatically -// if you manually change any fields in the NebulaCertificate. -func (nc *NebulaCertificate) VerifyWithCache(t time.Time, ncp *NebulaCAPool) (bool, error) { - return nc.verify(t, ncp, true) -} - -// ResetCache resets the cache used by VerifyWithCache. -func (nc *NebulaCertificate) ResetCache() { - nc.sha256sum.Store(nil) - nc.signatureVerified.Store(nil) -} - -// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc) -func (nc *NebulaCertificate) verify(t time.Time, ncp *NebulaCAPool, useCache bool) (bool, error) { - if ncp.isBlocklistedWithCache(nc, useCache) { - return false, ErrBlockListed - } - - signer, err := ncp.GetCAForCert(nc) - if err != nil { - return false, err - } - - if signer.Expired(t) { - return false, ErrRootExpired - } - - if nc.Expired(t) { - return false, ErrExpired - } - - if !nc.checkSignatureWithCache(signer.Details.PublicKey, useCache) { - return false, ErrSignatureMismatch - } - - if err := nc.CheckRootConstrains(signer); err != nil { - return false, err - } - - return true, nil -} - -// CheckRootConstrains returns an error if the certificate violates constraints set on the root (groups, ips, subnets) -func (nc *NebulaCertificate) CheckRootConstrains(signer *NebulaCertificate) error { - // Make sure this cert wasn't valid before the root - if signer.Details.NotAfter.Before(nc.Details.NotAfter) { - return fmt.Errorf("certificate expires after signing certificate") - } - - // Make sure this cert isn't valid after the root - if signer.Details.NotBefore.After(nc.Details.NotBefore) { - return fmt.Errorf("certificate is valid before the signing certificate") - } - - // If the signer has a limited set of groups make sure the cert only contains a subset - if len(signer.Details.InvertedGroups) > 0 { - for _, g := range nc.Details.Groups { - if _, ok := signer.Details.InvertedGroups[g]; !ok { - return fmt.Errorf("certificate contained a group not present on the signing ca: %s", g) - } - } - } - - // If the signer has a limited set of ip ranges to issue from make sure the cert only contains a subset - if len(signer.Details.Ips) > 0 { - for _, ip := range nc.Details.Ips { - if !netMatch(ip, signer.Details.Ips) { - return fmt.Errorf("certificate contained an ip assignment outside the limitations of the signing ca: %s", ip.String()) - } - } - } - - // If the signer has a limited set of subnet ranges to issue from make sure the cert only contains a subset - if len(signer.Details.Subnets) > 0 { - for _, subnet := range nc.Details.Subnets { - if !netMatch(subnet, signer.Details.Subnets) { - return fmt.Errorf("certificate contained a subnet assignment outside the limitations of the signing ca: %s", subnet) - } - } - } - - return nil -} - -// VerifyPrivateKey checks that the public key in the Nebula certificate and a supplied private key match -func (nc *NebulaCertificate) VerifyPrivateKey(curve Curve, key []byte) error { - if curve != nc.Details.Curve { - return fmt.Errorf("curve in cert and private key supplied don't match") - } - if nc.Details.IsCA { - switch curve { - case Curve_CURVE25519: - // the call to PublicKey below will panic slice bounds out of range otherwise - if len(key) != ed25519.PrivateKeySize { - return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") - } - - if !ed25519.PublicKey(nc.Details.PublicKey).Equal(ed25519.PrivateKey(key).Public()) { - return fmt.Errorf("public key in cert and private key supplied don't match") - } - case Curve_P256: - privkey, err := ecdh.P256().NewPrivateKey(key) - if err != nil { - return fmt.Errorf("cannot parse private key as P256") - } - pub := privkey.PublicKey().Bytes() - if !bytes.Equal(pub, nc.Details.PublicKey) { - return fmt.Errorf("public key in cert and private key supplied don't match") - } - default: - return fmt.Errorf("invalid curve: %s", curve) - } - return nil - } - - var pub []byte - switch curve { - case Curve_CURVE25519: - var err error - pub, err = curve25519.X25519(key, curve25519.Basepoint) - if err != nil { - return err - } - case Curve_P256: - privkey, err := ecdh.P256().NewPrivateKey(key) - if err != nil { - return err - } - pub = privkey.PublicKey().Bytes() - default: - return fmt.Errorf("invalid curve: %s", curve) - } - if !bytes.Equal(pub, nc.Details.PublicKey) { - return fmt.Errorf("public key in cert and private key supplied don't match") - } - - return nil -} - -// String will return a pretty printed representation of a nebula cert -func (nc *NebulaCertificate) String() string { - if nc == nil { - return "NebulaCertificate {}\n" - } - - s := "NebulaCertificate {\n" - s += "\tDetails {\n" - s += fmt.Sprintf("\t\tName: %v\n", nc.Details.Name) - - if len(nc.Details.Ips) > 0 { - s += "\t\tIps: [\n" - for _, ip := range nc.Details.Ips { - s += fmt.Sprintf("\t\t\t%v\n", ip.String()) - } - s += "\t\t]\n" - } else { - s += "\t\tIps: []\n" - } - - if len(nc.Details.Subnets) > 0 { - s += "\t\tSubnets: [\n" - for _, ip := range nc.Details.Subnets { - s += fmt.Sprintf("\t\t\t%v\n", ip.String()) - } - s += "\t\t]\n" - } else { - s += "\t\tSubnets: []\n" - } - - if len(nc.Details.Groups) > 0 { - s += "\t\tGroups: [\n" - for _, g := range nc.Details.Groups { - s += fmt.Sprintf("\t\t\t\"%v\"\n", g) - } - s += "\t\t]\n" - } else { - s += "\t\tGroups: []\n" - } - - s += fmt.Sprintf("\t\tNot before: %v\n", nc.Details.NotBefore) - s += fmt.Sprintf("\t\tNot After: %v\n", nc.Details.NotAfter) - s += fmt.Sprintf("\t\tIs CA: %v\n", nc.Details.IsCA) - s += fmt.Sprintf("\t\tIssuer: %s\n", nc.Details.Issuer) - s += fmt.Sprintf("\t\tPublic key: %x\n", nc.Details.PublicKey) - s += fmt.Sprintf("\t\tCurve: %s\n", nc.Details.Curve) - s += "\t}\n" - fp, err := nc.Sha256Sum() - if err == nil { - s += fmt.Sprintf("\tFingerprint: %s\n", fp) - } - s += fmt.Sprintf("\tSignature: %x\n", nc.Signature) - s += "}" - - return s -} - -// getRawDetails marshals the raw details into protobuf ready struct -func (nc *NebulaCertificate) getRawDetails() *RawNebulaCertificateDetails { - rd := &RawNebulaCertificateDetails{ - Name: nc.Details.Name, - Groups: nc.Details.Groups, - NotBefore: nc.Details.NotBefore.Unix(), - NotAfter: nc.Details.NotAfter.Unix(), - PublicKey: make([]byte, len(nc.Details.PublicKey)), - IsCA: nc.Details.IsCA, - Curve: nc.Details.Curve, - } - - for _, ipNet := range nc.Details.Ips { - rd.Ips = append(rd.Ips, ip2int(ipNet.IP), ip2int(ipNet.Mask)) - } - - for _, ipNet := range nc.Details.Subnets { - rd.Subnets = append(rd.Subnets, ip2int(ipNet.IP), ip2int(ipNet.Mask)) - } - - copy(rd.PublicKey, nc.Details.PublicKey[:]) - - // I know, this is terrible - rd.Issuer, _ = hex.DecodeString(nc.Details.Issuer) - - return rd -} - -// Marshal will marshal a nebula cert into a protobuf byte array -func (nc *NebulaCertificate) Marshal() ([]byte, error) { - rc := RawNebulaCertificate{ - Details: nc.getRawDetails(), - Signature: nc.Signature, - } - - return proto.Marshal(&rc) -} - -// MarshalToPEM will marshal a nebula cert into a protobuf byte array and pem encode the result -func (nc *NebulaCertificate) MarshalToPEM() ([]byte, error) { - b, err := nc.Marshal() - if err != nil { - return nil, err - } - return pem.EncodeToMemory(&pem.Block{Type: CertBanner, Bytes: b}), nil -} - -// Sha256Sum calculates a sha-256 sum of the marshaled certificate -func (nc *NebulaCertificate) Sha256Sum() (string, error) { - b, err := nc.Marshal() - if err != nil { - return "", err - } - - sum := sha256.Sum256(b) - return hex.EncodeToString(sum[:]), nil -} - -// NOTE: This uses an internal cache that will not be invalidated automatically -// if you manually change any fields in the NebulaCertificate. -func (nc *NebulaCertificate) sha256SumWithCache(useCache bool) (string, error) { - if !useCache { - return nc.Sha256Sum() - } - - if s := nc.sha256sum.Load(); s != nil { - return *s, nil - } - s, err := nc.Sha256Sum() - if err != nil { - return s, err - } - - nc.sha256sum.Store(&s) - return s, nil -} - -func (nc *NebulaCertificate) MarshalJSON() ([]byte, error) { - toString := func(ips []*net.IPNet) []string { - s := []string{} - for _, ip := range ips { - s = append(s, ip.String()) - } - return s - } - - fp, _ := nc.Sha256Sum() - jc := m{ - "details": m{ - "name": nc.Details.Name, - "ips": toString(nc.Details.Ips), - "subnets": toString(nc.Details.Subnets), - "groups": nc.Details.Groups, - "notBefore": nc.Details.NotBefore, - "notAfter": nc.Details.NotAfter, - "publicKey": fmt.Sprintf("%x", nc.Details.PublicKey), - "isCa": nc.Details.IsCA, - "issuer": nc.Details.Issuer, - "curve": nc.Details.Curve.String(), - }, - "fingerprint": fp, - "signature": fmt.Sprintf("%x", nc.Signature), - } - return json.Marshal(jc) -} - -//func (nc *NebulaCertificate) Copy() *NebulaCertificate { -// r, err := nc.Marshal() -// if err != nil { -// //TODO -// return nil -// } -// -// c, err := UnmarshalNebulaCertificate(r) -// return c -//} - -func (nc *NebulaCertificate) Copy() *NebulaCertificate { - c := &NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: nc.Details.Name, - Groups: make([]string, len(nc.Details.Groups)), - Ips: make([]*net.IPNet, len(nc.Details.Ips)), - Subnets: make([]*net.IPNet, len(nc.Details.Subnets)), - NotBefore: nc.Details.NotBefore, - NotAfter: nc.Details.NotAfter, - PublicKey: make([]byte, len(nc.Details.PublicKey)), - IsCA: nc.Details.IsCA, - Issuer: nc.Details.Issuer, - InvertedGroups: make(map[string]struct{}, len(nc.Details.InvertedGroups)), - }, - Signature: make([]byte, len(nc.Signature)), - } - - copy(c.Signature, nc.Signature) - copy(c.Details.Groups, nc.Details.Groups) - copy(c.Details.PublicKey, nc.Details.PublicKey) - - for i, p := range nc.Details.Ips { - c.Details.Ips[i] = &net.IPNet{ - IP: make(net.IP, len(p.IP)), - Mask: make(net.IPMask, len(p.Mask)), - } - copy(c.Details.Ips[i].IP, p.IP) - copy(c.Details.Ips[i].Mask, p.Mask) - } - - for i, p := range nc.Details.Subnets { - c.Details.Subnets[i] = &net.IPNet{ - IP: make(net.IP, len(p.IP)), - Mask: make(net.IPMask, len(p.Mask)), - } - copy(c.Details.Subnets[i].IP, p.IP) - copy(c.Details.Subnets[i].Mask, p.Mask) - } - - for g := range nc.Details.InvertedGroups { - c.Details.InvertedGroups[g] = struct{}{} - } - - return c -} - -func netMatch(certIp *net.IPNet, rootIps []*net.IPNet) bool { - for _, net := range rootIps { - if net.Contains(certIp.IP) && maskContains(net.Mask, certIp.Mask) { - return true - } - } - - return false -} - -func maskContains(caMask, certMask net.IPMask) bool { - caM := maskTo4(caMask) - cM := maskTo4(certMask) - // Make sure forcing to ipv4 didn't nuke us - if caM == nil || cM == nil { - return false - } - - // Make sure the cert mask is not greater than the ca mask - for i := 0; i < len(caMask); i++ { - if caM[i] > cM[i] { - return false - } - } - - return true -} - -func maskTo4(ip net.IPMask) net.IPMask { - if len(ip) == net.IPv4len { - return ip - } - - if len(ip) == net.IPv6len && isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff { - return ip[12:16] - } - - return nil -} - -func isZeros(b []byte) bool { - for i := 0; i < len(b); i++ { - if b[i] != 0 { - return false - } - } - return true -} - -func ip2int(ip []byte) uint32 { - if len(ip) == 16 { - return binary.BigEndian.Uint32(ip[12:16]) - } - return binary.BigEndian.Uint32(ip) -} - -func int2ip(nn uint32) net.IP { - ip := make(net.IP, net.IPv4len) - binary.BigEndian.PutUint32(ip, nn) - return ip + return c, nil } diff --git a/cert/cert_test.go b/cert/cert_test.go deleted file mode 100644 index 30e99ec..0000000 --- a/cert/cert_test.go +++ /dev/null @@ -1,1230 +0,0 @@ -package cert - -import ( - "crypto/ecdh" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "fmt" - "io" - "net" - "testing" - "time" - - "github.com/slackhq/nebula/test" - "github.com/stretchr/testify/assert" - "golang.org/x/crypto/curve25519" - "golang.org/x/crypto/ed25519" - "google.golang.org/protobuf/proto" -) - -func TestMarshalingNebulaCertificate(t *testing.T) { - before := time.Now().Add(time.Second * -60).Round(time.Second) - after := time.Now().Add(time.Second * 60).Round(time.Second) - pubKey := []byte("1234567890abcedfghij1234567890ab") - - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "testing", - Ips: []*net.IPNet{ - {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - }, - Subnets: []*net.IPNet{ - {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - }, - Groups: []string{"test-group1", "test-group2", "test-group3"}, - NotBefore: before, - NotAfter: after, - PublicKey: pubKey, - IsCA: false, - Issuer: "1234567890abcedfghij1234567890ab", - }, - Signature: []byte("1234567890abcedfghij1234567890ab"), - } - - b, err := nc.Marshal() - assert.Nil(t, err) - //t.Log("Cert size:", len(b)) - - nc2, err := UnmarshalNebulaCertificate(b) - assert.Nil(t, err) - - assert.Equal(t, nc.Signature, nc2.Signature) - assert.Equal(t, nc.Details.Name, nc2.Details.Name) - assert.Equal(t, nc.Details.NotBefore, nc2.Details.NotBefore) - assert.Equal(t, nc.Details.NotAfter, nc2.Details.NotAfter) - assert.Equal(t, nc.Details.PublicKey, nc2.Details.PublicKey) - assert.Equal(t, nc.Details.IsCA, nc2.Details.IsCA) - - // IP byte arrays can be 4 or 16 in length so we have to go this route - assert.Equal(t, len(nc.Details.Ips), len(nc2.Details.Ips)) - for i, wIp := range nc.Details.Ips { - assert.Equal(t, wIp.String(), nc2.Details.Ips[i].String()) - } - - assert.Equal(t, len(nc.Details.Subnets), len(nc2.Details.Subnets)) - for i, wIp := range nc.Details.Subnets { - assert.Equal(t, wIp.String(), nc2.Details.Subnets[i].String()) - } - - assert.EqualValues(t, nc.Details.Groups, nc2.Details.Groups) -} - -func TestNebulaCertificate_Sign(t *testing.T) { - before := time.Now().Add(time.Second * -60).Round(time.Second) - after := time.Now().Add(time.Second * 60).Round(time.Second) - pubKey := []byte("1234567890abcedfghij1234567890ab") - - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "testing", - Ips: []*net.IPNet{ - {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - }, - Subnets: []*net.IPNet{ - {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - }, - Groups: []string{"test-group1", "test-group2", "test-group3"}, - NotBefore: before, - NotAfter: after, - PublicKey: pubKey, - IsCA: false, - Issuer: "1234567890abcedfghij1234567890ab", - }, - } - - pub, priv, err := ed25519.GenerateKey(rand.Reader) - assert.Nil(t, err) - assert.False(t, nc.CheckSignature(pub)) - assert.Nil(t, nc.Sign(Curve_CURVE25519, priv)) - assert.True(t, nc.CheckSignature(pub)) - - _, err = nc.Marshal() - assert.Nil(t, err) - //t.Log("Cert size:", len(b)) -} - -func TestNebulaCertificate_SignP256(t *testing.T) { - before := time.Now().Add(time.Second * -60).Round(time.Second) - after := time.Now().Add(time.Second * 60).Round(time.Second) - pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab") - - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "testing", - Ips: []*net.IPNet{ - {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - }, - Subnets: []*net.IPNet{ - {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - }, - Groups: []string{"test-group1", "test-group2", "test-group3"}, - NotBefore: before, - NotAfter: after, - PublicKey: pubKey, - IsCA: false, - Curve: Curve_P256, - Issuer: "1234567890abcedfghij1234567890ab", - }, - } - - priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) - rawPriv := priv.D.FillBytes(make([]byte, 32)) - - assert.Nil(t, err) - assert.False(t, nc.CheckSignature(pub)) - assert.Nil(t, nc.Sign(Curve_P256, rawPriv)) - assert.True(t, nc.CheckSignature(pub)) - - _, err = nc.Marshal() - assert.Nil(t, err) - //t.Log("Cert size:", len(b)) -} - -func TestNebulaCertificate_Expired(t *testing.T) { - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ - NotBefore: time.Now().Add(time.Second * -60).Round(time.Second), - NotAfter: time.Now().Add(time.Second * 60).Round(time.Second), - }, - } - - assert.True(t, nc.Expired(time.Now().Add(time.Hour))) - assert.True(t, nc.Expired(time.Now().Add(-time.Hour))) - assert.False(t, nc.Expired(time.Now())) -} - -func TestNebulaCertificate_MarshalJSON(t *testing.T) { - time.Local = time.UTC - pubKey := []byte("1234567890abcedfghij1234567890ab") - - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "testing", - Ips: []*net.IPNet{ - {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - }, - Subnets: []*net.IPNet{ - {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - }, - Groups: []string{"test-group1", "test-group2", "test-group3"}, - NotBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC), - NotAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC), - PublicKey: pubKey, - IsCA: false, - Issuer: "1234567890abcedfghij1234567890ab", - }, - Signature: []byte("1234567890abcedfghij1234567890ab"), - } - - b, err := nc.MarshalJSON() - assert.Nil(t, err) - assert.Equal( - t, - "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"ips\":[\"10.1.1.1/24\",\"10.1.1.2/16\",\"10.1.1.3/ff00ff00\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"subnets\":[\"9.1.1.1/ff00ff00\",\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"26cb1c30ad7872c804c166b5150fa372f437aa3856b04edb4334b4470ec728e4\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\"}", - string(b), - ) -} - -func TestNebulaCertificate_Verify(t *testing.T) { - ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - - h, err := ca.Sha256Sum() - assert.Nil(t, err) - - caPool := NewCAPool() - caPool.CAs[h] = ca - - f, err := c.Sha256Sum() - assert.Nil(t, err) - caPool.BlocklistFingerprint(f) - - v, err := c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate is in the block list") - - caPool.ResetCertBlocklist() - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) - assert.Nil(t, err) - - v, err = c.Verify(time.Now().Add(time.Hour*1000), caPool) - assert.False(t, v) - assert.EqualError(t, err, "root certificate is expired") - - c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - v, err = c.Verify(time.Now().Add(time.Minute*6), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate is expired") - - // Test group assertion - ca, _, caKey, err = newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "test2"}) - assert.Nil(t, err) - - caPem, err := ca.MarshalToPEM() - assert.Nil(t, err) - - caPool = NewCAPool() - caPool.AddCACertificate(caPem) - - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "bad"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained a group not present on the signing ca: bad") - - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) - assert.Nil(t, err) -} - -func TestNebulaCertificate_VerifyP256(t *testing.T) { - ca, _, caKey, err := newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - - h, err := ca.Sha256Sum() - assert.Nil(t, err) - - caPool := NewCAPool() - caPool.CAs[h] = ca - - f, err := c.Sha256Sum() - assert.Nil(t, err) - caPool.BlocklistFingerprint(f) - - v, err := c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate is in the block list") - - caPool.ResetCertBlocklist() - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) - assert.Nil(t, err) - - v, err = c.Verify(time.Now().Add(time.Hour*1000), caPool) - assert.False(t, v) - assert.EqualError(t, err, "root certificate is expired") - - c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - v, err = c.Verify(time.Now().Add(time.Minute*6), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate is expired") - - // Test group assertion - ca, _, caKey, err = newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "test2"}) - assert.Nil(t, err) - - caPem, err := ca.MarshalToPEM() - assert.Nil(t, err) - - caPool = NewCAPool() - caPool.AddCACertificate(caPem) - - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "bad"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained a group not present on the signing ca: bad") - - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) - assert.Nil(t, err) -} - -func TestNebulaCertificate_Verify_IPs(t *testing.T) { - _, caIp1, _ := net.ParseCIDR("10.0.0.0/16") - _, caIp2, _ := net.ParseCIDR("192.168.0.0/24") - ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{caIp1, caIp2}, []*net.IPNet{}, []string{"test"}) - assert.Nil(t, err) - - caPem, err := ca.MarshalToPEM() - assert.Nil(t, err) - - caPool := NewCAPool() - caPool.AddCACertificate(caPem) - - // ip is outside the network - cIp1 := &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}} - cIp2 := &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 0, 0}} - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"}) - assert.Nil(t, err) - v, err := c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.1.0.0/24") - - // ip is outside the network reversed order of above - cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.1.0.0/24") - - // ip is within the network but mask is outside - cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.0.1.0/15") - - // ip is within the network but mask is outside reversed order of above - cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.0.1.0/15") - - // ip and mask are within the network - cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 255, 0, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 128}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) - assert.Nil(t, err) - - // Exact matches - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{caIp1, caIp2}, []*net.IPNet{}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) - assert.Nil(t, err) - - // Exact matches reversed - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{caIp2, caIp1}, []*net.IPNet{}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) - assert.Nil(t, err) - - // Exact matches reversed with just 1 - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{caIp1}, []*net.IPNet{}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) - assert.Nil(t, err) -} - -func TestNebulaCertificate_Verify_Subnets(t *testing.T) { - _, caIp1, _ := net.ParseCIDR("10.0.0.0/16") - _, caIp2, _ := net.ParseCIDR("192.168.0.0/24") - ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp1, caIp2}, []string{"test"}) - assert.Nil(t, err) - - caPem, err := ca.MarshalToPEM() - assert.Nil(t, err) - - caPool := NewCAPool() - caPool.AddCACertificate(caPem) - - // ip is outside the network - cIp1 := &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}} - cIp2 := &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 0, 0}} - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"}) - assert.Nil(t, err) - v, err := c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.1.0.0/24") - - // ip is outside the network reversed order of above - cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.1.0.0/24") - - // ip is within the network but mask is outside - cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.0.1.0/15") - - // ip is within the network but mask is outside reversed order of above - cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.0.1.0/15") - - // ip and mask are within the network - cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 255, 0, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 128}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) - assert.Nil(t, err) - - // Exact matches - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp1, caIp2}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) - assert.Nil(t, err) - - // Exact matches reversed - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp2, caIp1}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) - assert.Nil(t, err) - - // Exact matches reversed with just 1 - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp1}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) - assert.Nil(t, err) -} - -func TestNebulaCertificate_VerifyPrivateKey(t *testing.T) { - ca, _, caKey, err := newTestCaCert(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey) - assert.Nil(t, err) - - _, _, caKey2, err := newTestCaCert(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) - assert.NotNil(t, err) - - c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) - err = c.VerifyPrivateKey(Curve_CURVE25519, priv) - assert.Nil(t, err) - - _, priv2 := x25519Keypair() - err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) - assert.NotNil(t, err) -} - -func TestNebulaCertificate_VerifyPrivateKeyP256(t *testing.T) { - ca, _, caKey, err := newTestCaCertP256(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - err = ca.VerifyPrivateKey(Curve_P256, caKey) - assert.Nil(t, err) - - _, _, caKey2, err := newTestCaCertP256(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - err = ca.VerifyPrivateKey(Curve_P256, caKey2) - assert.NotNil(t, err) - - c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) - err = c.VerifyPrivateKey(Curve_P256, priv) - assert.Nil(t, err) - - _, priv2 := p256Keypair() - err = c.VerifyPrivateKey(Curve_P256, priv2) - assert.NotNil(t, err) -} - -func TestNewCAPoolFromBytes(t *testing.T) { - noNewLines := ` -# Current provisional, Remove once everything moves over to the real root. ------BEGIN NEBULA CERTIFICATE----- -CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL -vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv -bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB ------END NEBULA CERTIFICATE----- -# root-ca01 ------BEGIN NEBULA CERTIFICATE----- -CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG -BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf -8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF ------END NEBULA CERTIFICATE----- -` - - withNewLines := ` -# Current provisional, Remove once everything moves over to the real root. - ------BEGIN NEBULA CERTIFICATE----- -CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL -vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv -bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB ------END NEBULA CERTIFICATE----- - -# root-ca01 - - ------BEGIN NEBULA CERTIFICATE----- -CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG -BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf -8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF ------END NEBULA CERTIFICATE----- - -` - - expired := ` -# expired certificate ------BEGIN NEBULA CERTIFICATE----- -CjkKB2V4cGlyZWQouPmWjQYwufmWjQY6ILCRaoCkJlqHgv5jfDN4lzLHBvDzaQm4 -vZxfu144hmgjQAESQG4qlnZi8DncvD/LDZnLgJHOaX1DWCHHEh59epVsC+BNgTie -WH1M9n4O7cFtGlM6sJJOS+rCVVEJ3ABS7+MPdQs= ------END NEBULA CERTIFICATE----- -` - - p256 := ` -# p256 certificate ------BEGIN NEBULA CERTIFICATE----- -CmYKEG5lYnVsYSBQMjU2IHRlc3Qo4s+7mgYw4tXrsAc6QQRkaW2jFmllYvN4+/k2 -6tctO9sPT3jOx8ES6M1nIqOhpTmZeabF/4rELDqPV4aH5jfJut798DUXql0FlF8H -76gvQAGgBgESRzBFAiEAib0/te6eMiZOKD8gdDeloMTS0wGuX2t0C7TFdUhAQzgC -IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX ------END NEBULA CERTIFICATE----- -` - - rootCA := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "nebula root ca", - }, - } - - rootCA01 := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "nebula root ca 01", - }, - } - - rootCAP256 := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "nebula P256 test", - }, - } - - p, err := NewCAPoolFromBytes([]byte(noNewLines)) - assert.Nil(t, err) - assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name) - assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name) - - pp, err := NewCAPoolFromBytes([]byte(withNewLines)) - assert.Nil(t, err) - assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name) - assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name) - - // expired cert, no valid certs - ppp, err := NewCAPoolFromBytes([]byte(expired)) - assert.Equal(t, ErrExpired, err) - assert.Equal(t, ppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Details.Name, "expired") - - // expired cert, with valid certs - pppp, err := NewCAPoolFromBytes(append([]byte(expired), noNewLines...)) - assert.Equal(t, ErrExpired, err) - assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name) - assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name) - assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Details.Name, "expired") - assert.Equal(t, len(pppp.CAs), 3) - - ppppp, err := NewCAPoolFromBytes([]byte(p256)) - assert.Nil(t, err) - assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Details.Name, rootCAP256.Details.Name) - assert.Equal(t, len(ppppp.CAs), 1) -} - -func appendByteSlices(b ...[]byte) []byte { - retSlice := []byte{} - for _, v := range b { - retSlice = append(retSlice, v...) - } - return retSlice -} - -func TestUnmrshalCertPEM(t *testing.T) { - goodCert := []byte(` -# A good cert ------BEGIN NEBULA CERTIFICATE----- -CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL -vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv -bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB ------END NEBULA CERTIFICATE----- -`) - badBanner := []byte(`# A bad banner ------BEGIN NOT A NEBULA CERTIFICATE----- -CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL -vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv -bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB ------END NOT A NEBULA CERTIFICATE----- -`) - invalidPem := []byte(`# Not a valid PEM format --BEGIN NEBULA CERTIFICATE----- -CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL -vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv -bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB --END NEBULA CERTIFICATE----`) - - certBundle := appendByteSlices(goodCert, badBanner, invalidPem) - - // Success test case - cert, rest, err := UnmarshalNebulaCertificateFromPEM(certBundle) - assert.NotNil(t, cert) - assert.Equal(t, rest, append(badBanner, invalidPem...)) - assert.Nil(t, err) - - // Fail due to invalid banner. - cert, rest, err = UnmarshalNebulaCertificateFromPEM(rest) - assert.Nil(t, cert) - assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "bytes did not contain a proper nebula certificate banner") - - // Fail due to ivalid PEM format, because - // it's missing the requisite pre-encapsulation boundary. - cert, rest, err = UnmarshalNebulaCertificateFromPEM(rest) - assert.Nil(t, cert) - assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") -} - -func TestUnmarshalSigningPrivateKey(t *testing.T) { - privKey := []byte(`# A good key ------BEGIN NEBULA ED25519 PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== ------END NEBULA ED25519 PRIVATE KEY----- -`) - privP256Key := []byte(`# A good key ------BEGIN NEBULA ECDSA P256 PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= ------END NEBULA ECDSA P256 PRIVATE KEY----- -`) - shortKey := []byte(`# A short key ------BEGIN NEBULA ED25519 PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA ------END NEBULA ED25519 PRIVATE KEY----- -`) - invalidBanner := []byte(`# Invalid banner ------BEGIN NOT A NEBULA PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== ------END NOT A NEBULA PRIVATE KEY----- -`) - invalidPem := []byte(`# Not a valid PEM format --BEGIN NEBULA ED25519 PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== --END NEBULA ED25519 PRIVATE KEY-----`) - - keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem) - - // Success test case - k, rest, curve, err := UnmarshalSigningPrivateKey(keyBundle) - assert.Len(t, k, 64) - assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) - assert.Equal(t, Curve_CURVE25519, curve) - assert.Nil(t, err) - - // Success test case - k, rest, curve, err = UnmarshalSigningPrivateKey(rest) - assert.Len(t, k, 32) - assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) - assert.Equal(t, Curve_P256, curve) - assert.Nil(t, err) - - // Fail due to short key - k, rest, curve, err = UnmarshalSigningPrivateKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key") - - // Fail due to invalid banner - k, rest, curve, err = UnmarshalSigningPrivateKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "bytes did not contain a proper nebula Ed25519/ECDSA private key banner") - - // Fail due to ivalid PEM format, because - // it's missing the requisite pre-encapsulation boundary. - k, rest, curve, err = UnmarshalSigningPrivateKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") -} - -func TestDecryptAndUnmarshalSigningPrivateKey(t *testing.T) { - passphrase := []byte("DO NOT USE THIS KEY") - privKey := []byte(`# A good key ------BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- -CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT -oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl -+Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB -qrlJ69wer3ZUHFXA ------END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- -`) - shortKey := []byte(`# A key which, once decrypted, is too short ------BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- -CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCoga5h8owMEBWRSMMJKzuUvWce7 -k0qlBkQmCxiuLh80MuASW70YcKt8jeEIS2axo2V6zAKA9TSMcCsJW1kDDXEtL/xe -GLF5T7sDl5COp4LU3pGxpV+KoeQ/S3gQCAAcnaOtnJQX+aSDnbO3jCHyP7U9CHbs -rQr3bdH3Oy/WiYU= ------END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- -`) - invalidBanner := []byte(`# Invalid banner (not encrypted) ------BEGIN NEBULA ED25519 PRIVATE KEY----- -bWRp2CTVFhW9HD/qCd28ltDgK3w8VXSeaEYczDWos8sMUBqDb9jP3+NYwcS4lURG -XgLvodMXZJuaFPssp+WwtA== ------END NEBULA ED25519 PRIVATE KEY----- -`) - invalidPem := []byte(`# Not a valid PEM format --BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- -CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT -oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl -+Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB -qrlJ69wer3ZUHFXA --END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- -`) - - keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem) - - // Success test case - curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle) - assert.Nil(t, err) - assert.Equal(t, Curve_CURVE25519, curve) - assert.Len(t, k, 64) - assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) - - // Fail due to short key - curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) - assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key") - assert.Nil(t, k) - assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - - // Fail due to invalid banner - curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) - assert.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") - assert.Nil(t, k) - assert.Equal(t, rest, invalidPem) - - // Fail due to ivalid PEM format, because - // it's missing the requisite pre-encapsulation boundary. - curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") - assert.Nil(t, k) - assert.Equal(t, rest, invalidPem) - - // Fail due to invalid passphrase - curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey) - assert.EqualError(t, err, "invalid passphrase or corrupt private key") - assert.Nil(t, k) - assert.Equal(t, rest, []byte{}) -} - -func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) { - // Having proved that decryption works correctly above, we can test the - // encryption function produces a value which can be decrypted - passphrase := []byte("passphrase") - bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") - kdfParams := NewArgon2Parameters(64*1024, 4, 3) - key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams) - assert.Nil(t, err) - - // Verify the "key" can be decrypted successfully - curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key) - assert.Len(t, k, 64) - assert.Equal(t, Curve_CURVE25519, curve) - assert.Equal(t, rest, []byte{}) - assert.Nil(t, err) - - // EncryptAndMarshalEd25519PrivateKey does not create any errors itself -} - -func TestUnmarshalPrivateKey(t *testing.T) { - privKey := []byte(`# A good key ------BEGIN NEBULA X25519 PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= ------END NEBULA X25519 PRIVATE KEY----- -`) - privP256Key := []byte(`# A good key ------BEGIN NEBULA P256 PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= ------END NEBULA P256 PRIVATE KEY----- -`) - shortKey := []byte(`# A short key ------BEGIN NEBULA X25519 PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== ------END NEBULA X25519 PRIVATE KEY----- -`) - invalidBanner := []byte(`# Invalid banner ------BEGIN NOT A NEBULA PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= ------END NOT A NEBULA PRIVATE KEY----- -`) - invalidPem := []byte(`# Not a valid PEM format --BEGIN NEBULA X25519 PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= --END NEBULA X25519 PRIVATE KEY-----`) - - keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem) - - // Success test case - k, rest, curve, err := UnmarshalPrivateKey(keyBundle) - assert.Len(t, k, 32) - assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) - assert.Equal(t, Curve_CURVE25519, curve) - assert.Nil(t, err) - - // Success test case - k, rest, curve, err = UnmarshalPrivateKey(rest) - assert.Len(t, k, 32) - assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) - assert.Equal(t, Curve_P256, curve) - assert.Nil(t, err) - - // Fail due to short key - k, rest, curve, err = UnmarshalPrivateKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key") - - // Fail due to invalid banner - k, rest, curve, err = UnmarshalPrivateKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "bytes did not contain a proper nebula private key banner") - - // Fail due to ivalid PEM format, because - // it's missing the requisite pre-encapsulation boundary. - k, rest, curve, err = UnmarshalPrivateKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") -} - -func TestUnmarshalEd25519PublicKey(t *testing.T) { - pubKey := []byte(`# A good key ------BEGIN NEBULA ED25519 PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= ------END NEBULA ED25519 PUBLIC KEY----- -`) - shortKey := []byte(`# A short key ------BEGIN NEBULA ED25519 PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== ------END NEBULA ED25519 PUBLIC KEY----- -`) - invalidBanner := []byte(`# Invalid banner ------BEGIN NOT A NEBULA PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= ------END NOT A NEBULA PUBLIC KEY----- -`) - invalidPem := []byte(`# Not a valid PEM format --BEGIN NEBULA ED25519 PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= --END NEBULA ED25519 PUBLIC KEY-----`) - - keyBundle := appendByteSlices(pubKey, shortKey, invalidBanner, invalidPem) - - // Success test case - k, rest, err := UnmarshalEd25519PublicKey(keyBundle) - assert.Equal(t, len(k), 32) - assert.Nil(t, err) - assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) - - // Fail due to short key - k, rest, err = UnmarshalEd25519PublicKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 32 bytes, is invalid ed25519 public key") - - // Fail due to invalid banner - k, rest, err = UnmarshalEd25519PublicKey(rest) - assert.Nil(t, k) - assert.EqualError(t, err, "bytes did not contain a proper nebula Ed25519 public key banner") - assert.Equal(t, rest, invalidPem) - - // Fail due to ivalid PEM format, because - // it's missing the requisite pre-encapsulation boundary. - k, rest, err = UnmarshalEd25519PublicKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") -} - -func TestUnmarshalX25519PublicKey(t *testing.T) { - pubKey := []byte(`# A good key ------BEGIN NEBULA X25519 PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= ------END NEBULA X25519 PUBLIC KEY----- -`) - pubP256Key := []byte(`# A good key ------BEGIN NEBULA P256 PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA -AAAAAAAAAAAAAAAAAAAAAAA= ------END NEBULA P256 PUBLIC KEY----- -`) - shortKey := []byte(`# A short key ------BEGIN NEBULA X25519 PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== ------END NEBULA X25519 PUBLIC KEY----- -`) - invalidBanner := []byte(`# Invalid banner ------BEGIN NOT A NEBULA PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= ------END NOT A NEBULA PUBLIC KEY----- -`) - invalidPem := []byte(`# Not a valid PEM format --BEGIN NEBULA X25519 PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= --END NEBULA X25519 PUBLIC KEY-----`) - - keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem) - - // Success test case - k, rest, curve, err := UnmarshalPublicKey(keyBundle) - assert.Equal(t, len(k), 32) - assert.Nil(t, err) - assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem)) - assert.Equal(t, Curve_CURVE25519, curve) - - // Success test case - k, rest, curve, err = UnmarshalPublicKey(rest) - assert.Equal(t, len(k), 65) - assert.Nil(t, err) - assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) - assert.Equal(t, Curve_P256, curve) - - // Fail due to short key - k, rest, curve, err = UnmarshalPublicKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") - - // Fail due to invalid banner - k, rest, curve, err = UnmarshalPublicKey(rest) - assert.Nil(t, k) - assert.EqualError(t, err, "bytes did not contain a proper nebula public key banner") - assert.Equal(t, rest, invalidPem) - - // Fail due to ivalid PEM format, because - // it's missing the requisite pre-encapsulation boundary. - k, rest, curve, err = UnmarshalPublicKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") -} - -// Ensure that upgrading the protobuf library does not change how certificates -// are marshalled, since this would break signature verification -func TestMarshalingNebulaCertificateConsistency(t *testing.T) { - before := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) - after := time.Date(2017, time.January, 18, 28, 40, 0, 0, time.UTC) - pubKey := []byte("1234567890abcedfghij1234567890ab") - - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "testing", - Ips: []*net.IPNet{ - {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - }, - Subnets: []*net.IPNet{ - {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - }, - Groups: []string{"test-group1", "test-group2", "test-group3"}, - NotBefore: before, - NotAfter: after, - PublicKey: pubKey, - IsCA: false, - Issuer: "1234567890abcedfghij1234567890ab", - }, - Signature: []byte("1234567890abcedfghij1234567890ab"), - } - - b, err := nc.Marshal() - assert.Nil(t, err) - //t.Log("Cert size:", len(b)) - assert.Equal(t, "0aa2010a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b)) - - b, err = proto.Marshal(nc.getRawDetails()) - assert.Nil(t, err) - //t.Log("Raw cert size:", len(b)) - assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b)) -} - -func TestNebulaCertificate_Copy(t *testing.T) { - ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - cc := c.Copy() - - test.AssertDeepCopyEqual(t, c, cc) -} - -func TestUnmarshalNebulaCertificate(t *testing.T) { - // Test that we don't panic with an invalid certificate (#332) - data := []byte("\x98\x00\x00") - _, err := UnmarshalNebulaCertificate(data) - assert.EqualError(t, err, "encoded Details was nil") -} - -func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) { - pub, priv, err := ed25519.GenerateKey(rand.Reader) - if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) - } - if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) - } - - nc := &NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "test ca", - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: true, - InvertedGroups: make(map[string]struct{}), - }, - } - - if len(ips) > 0 { - nc.Details.Ips = ips - } - - if len(subnets) > 0 { - nc.Details.Subnets = subnets - } - - if len(groups) > 0 { - nc.Details.Groups = groups - } - - err = nc.Sign(Curve_CURVE25519, priv) - if err != nil { - return nil, nil, nil, err - } - return nc, pub, priv, nil -} - -func newTestCaCertP256(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) { - priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) - rawPriv := priv.D.FillBytes(make([]byte, 32)) - - if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) - } - if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) - } - - nc := &NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "test ca", - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: true, - Curve: Curve_P256, - InvertedGroups: make(map[string]struct{}), - }, - } - - if len(ips) > 0 { - nc.Details.Ips = ips - } - - if len(subnets) > 0 { - nc.Details.Subnets = subnets - } - - if len(groups) > 0 { - nc.Details.Groups = groups - } - - err = nc.Sign(Curve_P256, rawPriv) - if err != nil { - return nil, nil, nil, err - } - return nc, pub, rawPriv, nil -} - -func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) { - issuer, err := ca.Sha256Sum() - if err != nil { - return nil, nil, nil, err - } - - if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) - } - if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) - } - - if len(groups) == 0 { - groups = []string{"test-group1", "test-group2", "test-group3"} - } - - if len(ips) == 0 { - ips = []*net.IPNet{ - {IP: net.ParseIP("10.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())}, - {IP: net.ParseIP("10.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())}, - {IP: net.ParseIP("10.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())}, - } - } - - if len(subnets) == 0 { - subnets = []*net.IPNet{ - {IP: net.ParseIP("9.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())}, - {IP: net.ParseIP("9.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())}, - {IP: net.ParseIP("9.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())}, - } - } - - var pub, rawPriv []byte - - switch ca.Details.Curve { - case Curve_CURVE25519: - pub, rawPriv = x25519Keypair() - case Curve_P256: - pub, rawPriv = p256Keypair() - default: - return nil, nil, nil, fmt.Errorf("unknown curve: %v", ca.Details.Curve) - } - - nc := &NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "testing", - Ips: ips, - Subnets: subnets, - Groups: groups, - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: false, - Curve: ca.Details.Curve, - Issuer: issuer, - InvertedGroups: make(map[string]struct{}), - }, - } - - err = nc.Sign(ca.Details.Curve, key) - if err != nil { - return nil, nil, nil, err - } - - return nc, pub, rawPriv, nil -} - -func x25519Keypair() ([]byte, []byte) { - privkey := make([]byte, 32) - if _, err := io.ReadFull(rand.Reader, privkey); err != nil { - panic(err) - } - - pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint) - if err != nil { - panic(err) - } - - return pubkey, privkey -} - -func p256Keypair() ([]byte, []byte) { - privkey, err := ecdh.P256().GenerateKey(rand.Reader) - if err != nil { - panic(err) - } - pubkey := privkey.PublicKey() - return pubkey.Bytes(), privkey.Bytes() -} diff --git a/cert/cert_v1.go b/cert/cert_v1.go new file mode 100644 index 0000000..6bb146f --- /dev/null +++ b/cert/cert_v1.go @@ -0,0 +1,489 @@ +package cert + +import ( + "bytes" + "crypto/ecdh" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "encoding/json" + "encoding/pem" + "fmt" + "net" + "net/netip" + "time" + + "golang.org/x/crypto/curve25519" + "google.golang.org/protobuf/proto" +) + +const publicKeyLen = 32 + +type certificateV1 struct { + details detailsV1 + signature []byte +} + +type detailsV1 struct { + name string + networks []netip.Prefix + unsafeNetworks []netip.Prefix + groups []string + notBefore time.Time + notAfter time.Time + publicKey []byte + isCA bool + issuer string + + curve Curve +} + +type m map[string]interface{} + +func (c *certificateV1) Version() Version { + return Version1 +} + +func (c *certificateV1) Curve() Curve { + return c.details.curve +} + +func (c *certificateV1) Groups() []string { + return c.details.groups +} + +func (c *certificateV1) IsCA() bool { + return c.details.isCA +} + +func (c *certificateV1) Issuer() string { + return c.details.issuer +} + +func (c *certificateV1) Name() string { + return c.details.name +} + +func (c *certificateV1) Networks() []netip.Prefix { + return c.details.networks +} + +func (c *certificateV1) NotAfter() time.Time { + return c.details.notAfter +} + +func (c *certificateV1) NotBefore() time.Time { + return c.details.notBefore +} + +func (c *certificateV1) PublicKey() []byte { + return c.details.publicKey +} + +func (c *certificateV1) Signature() []byte { + return c.signature +} + +func (c *certificateV1) UnsafeNetworks() []netip.Prefix { + return c.details.unsafeNetworks +} + +func (c *certificateV1) Fingerprint() (string, error) { + b, err := c.Marshal() + if err != nil { + return "", err + } + + sum := sha256.Sum256(b) + return hex.EncodeToString(sum[:]), nil +} + +func (c *certificateV1) CheckSignature(key []byte) bool { + b, err := proto.Marshal(c.getRawDetails()) + if err != nil { + return false + } + switch c.details.curve { + case Curve_CURVE25519: + return ed25519.Verify(key, b, c.signature) + case Curve_P256: + x, y := elliptic.Unmarshal(elliptic.P256(), key) + pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} + hashed := sha256.Sum256(b) + return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature) + default: + return false + } +} + +func (c *certificateV1) Expired(t time.Time) bool { + return c.details.notBefore.After(t) || c.details.notAfter.Before(t) +} + +func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { + if curve != c.details.curve { + return fmt.Errorf("curve in cert and private key supplied don't match") + } + if c.details.isCA { + switch curve { + case Curve_CURVE25519: + // the call to PublicKey below will panic slice bounds out of range otherwise + if len(key) != ed25519.PrivateKeySize { + return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") + } + + if !ed25519.PublicKey(c.details.publicKey).Equal(ed25519.PrivateKey(key).Public()) { + return fmt.Errorf("public key in cert and private key supplied don't match") + } + case Curve_P256: + privkey, err := ecdh.P256().NewPrivateKey(key) + if err != nil { + return fmt.Errorf("cannot parse private key as P256: %w", err) + } + pub := privkey.PublicKey().Bytes() + if !bytes.Equal(pub, c.details.publicKey) { + return fmt.Errorf("public key in cert and private key supplied don't match") + } + default: + return fmt.Errorf("invalid curve: %s", curve) + } + return nil + } + + var pub []byte + switch curve { + case Curve_CURVE25519: + var err error + pub, err = curve25519.X25519(key, curve25519.Basepoint) + if err != nil { + return err + } + case Curve_P256: + privkey, err := ecdh.P256().NewPrivateKey(key) + if err != nil { + return err + } + pub = privkey.PublicKey().Bytes() + default: + return fmt.Errorf("invalid curve: %s", curve) + } + if !bytes.Equal(pub, c.details.publicKey) { + return fmt.Errorf("public key in cert and private key supplied don't match") + } + + return nil +} + +// getRawDetails marshals the raw details into protobuf ready struct +func (c *certificateV1) getRawDetails() *RawNebulaCertificateDetails { + rd := &RawNebulaCertificateDetails{ + Name: c.details.name, + Groups: c.details.groups, + NotBefore: c.details.notBefore.Unix(), + NotAfter: c.details.notAfter.Unix(), + PublicKey: make([]byte, len(c.details.publicKey)), + IsCA: c.details.isCA, + Curve: c.details.curve, + } + + for _, ipNet := range c.details.networks { + mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) + rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask)) + } + + for _, ipNet := range c.details.unsafeNetworks { + mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) + rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask)) + } + + copy(rd.PublicKey, c.details.publicKey[:]) + + // I know, this is terrible + rd.Issuer, _ = hex.DecodeString(c.details.issuer) + + return rd +} + +func (c *certificateV1) String() string { + b, err := json.MarshalIndent(c.marshalJSON(), "", "\t") + if err != nil { + return fmt.Sprintf("", err) + } + return string(b) +} + +func (c *certificateV1) MarshalForHandshakes() ([]byte, error) { + pubKey := c.details.publicKey + c.details.publicKey = nil + rawCertNoKey, err := c.Marshal() + if err != nil { + return nil, err + } + c.details.publicKey = pubKey + return rawCertNoKey, nil +} + +func (c *certificateV1) Marshal() ([]byte, error) { + rc := RawNebulaCertificate{ + Details: c.getRawDetails(), + Signature: c.signature, + } + + return proto.Marshal(&rc) +} + +func (c *certificateV1) MarshalPEM() ([]byte, error) { + b, err := c.Marshal() + if err != nil { + return nil, err + } + return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil +} + +func (c *certificateV1) MarshalJSON() ([]byte, error) { + return json.Marshal(c.marshalJSON()) +} + +func (c *certificateV1) marshalJSON() m { + fp, _ := c.Fingerprint() + return m{ + "version": Version1, + "details": m{ + "name": c.details.name, + "networks": c.details.networks, + "unsafeNetworks": c.details.unsafeNetworks, + "groups": c.details.groups, + "notBefore": c.details.notBefore, + "notAfter": c.details.notAfter, + "publicKey": fmt.Sprintf("%x", c.details.publicKey), + "isCa": c.details.isCA, + "issuer": c.details.issuer, + "curve": c.details.curve.String(), + }, + "fingerprint": fp, + "signature": fmt.Sprintf("%x", c.Signature()), + } +} + +func (c *certificateV1) Copy() Certificate { + nc := &certificateV1{ + details: detailsV1{ + name: c.details.name, + notBefore: c.details.notBefore, + notAfter: c.details.notAfter, + publicKey: make([]byte, len(c.details.publicKey)), + isCA: c.details.isCA, + issuer: c.details.issuer, + curve: c.details.curve, + }, + signature: make([]byte, len(c.signature)), + } + + if c.details.groups != nil { + nc.details.groups = make([]string, len(c.details.groups)) + copy(nc.details.groups, c.details.groups) + } + + if c.details.networks != nil { + nc.details.networks = make([]netip.Prefix, len(c.details.networks)) + copy(nc.details.networks, c.details.networks) + } + + if c.details.unsafeNetworks != nil { + nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks)) + copy(nc.details.unsafeNetworks, c.details.unsafeNetworks) + } + + copy(nc.signature, c.signature) + copy(nc.details.publicKey, c.details.publicKey) + + return nc +} + +func (c *certificateV1) fromTBSCertificate(t *TBSCertificate) error { + c.details = detailsV1{ + name: t.Name, + networks: t.Networks, + unsafeNetworks: t.UnsafeNetworks, + groups: t.Groups, + notBefore: t.NotBefore, + notAfter: t.NotAfter, + publicKey: t.PublicKey, + isCA: t.IsCA, + curve: t.Curve, + issuer: t.issuer, + } + + return c.validate() +} + +func (c *certificateV1) validate() error { + // Empty names are allowed + + if len(c.details.publicKey) == 0 { + return ErrInvalidPublicKey + } + + // Original v1 rules allowed multiple networks to be present but ignored all but the first one. + // Continue to allow this behavior + if !c.details.isCA && len(c.details.networks) == 0 { + return NewErrInvalidCertificateProperties("non-CA certificates must contain exactly one network") + } + + for _, network := range c.details.networks { + if !network.IsValid() || !network.Addr().IsValid() { + return NewErrInvalidCertificateProperties("invalid network: %s", network) + } + + if network.Addr().Is6() { + return NewErrInvalidCertificateProperties("certificate may not contain IPv6 networks: %v", network) + } + + if network.Addr().IsUnspecified() { + return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network) + } + + if network.Addr().Zone() != "" { + return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network) + } + } + + for _, network := range c.details.unsafeNetworks { + if !network.IsValid() || !network.Addr().IsValid() { + return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network) + } + + if network.Addr().Is6() { + return NewErrInvalidCertificateProperties("certificate may not contain IPv6 unsafe networks: %v", network) + } + + if network.Addr().Zone() != "" { + return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network) + } + } + + // v1 doesn't bother with sort order or uniqueness of networks or unsafe networks. + // We can't modify the unmarshalled data because verification requires re-marshalling and a re-ordered + // unsafe networks would result in a different signature. + + return nil +} + +func (c *certificateV1) marshalForSigning() ([]byte, error) { + b, err := proto.Marshal(c.getRawDetails()) + if err != nil { + return nil, err + } + return b, nil +} + +func (c *certificateV1) setSignature(b []byte) error { + if len(b) == 0 { + return ErrEmptySignature + } + c.signature = b + return nil +} + +// unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert +// if the publicKey is provided here then it is not required to be present in `b` +func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) { + if len(b) == 0 { + return nil, fmt.Errorf("nil byte array") + } + var rc RawNebulaCertificate + err := proto.Unmarshal(b, &rc) + if err != nil { + return nil, err + } + + if rc.Details == nil { + return nil, fmt.Errorf("encoded Details was nil") + } + + if len(rc.Details.Ips)%2 != 0 { + return nil, fmt.Errorf("encoded IPs should be in pairs, an odd number was found") + } + + if len(rc.Details.Subnets)%2 != 0 { + return nil, fmt.Errorf("encoded Subnets should be in pairs, an odd number was found") + } + + nc := certificateV1{ + details: detailsV1{ + name: rc.Details.Name, + groups: make([]string, len(rc.Details.Groups)), + networks: make([]netip.Prefix, len(rc.Details.Ips)/2), + unsafeNetworks: make([]netip.Prefix, len(rc.Details.Subnets)/2), + notBefore: time.Unix(rc.Details.NotBefore, 0), + notAfter: time.Unix(rc.Details.NotAfter, 0), + publicKey: make([]byte, len(rc.Details.PublicKey)), + isCA: rc.Details.IsCA, + curve: rc.Details.Curve, + }, + signature: make([]byte, len(rc.Signature)), + } + + copy(nc.signature, rc.Signature) + copy(nc.details.groups, rc.Details.Groups) + nc.details.issuer = hex.EncodeToString(rc.Details.Issuer) + + if len(publicKey) > 0 { + nc.details.publicKey = publicKey + } + + copy(nc.details.publicKey, rc.Details.PublicKey) + + var ip netip.Addr + for i, rawIp := range rc.Details.Ips { + if i%2 == 0 { + ip = int2addr(rawIp) + } else { + ones, _ := net.IPMask(int2ip(rawIp)).Size() + nc.details.networks[i/2] = netip.PrefixFrom(ip, ones) + } + } + + for i, rawIp := range rc.Details.Subnets { + if i%2 == 0 { + ip = int2addr(rawIp) + } else { + ones, _ := net.IPMask(int2ip(rawIp)).Size() + nc.details.unsafeNetworks[i/2] = netip.PrefixFrom(ip, ones) + } + } + + err = nc.validate() + if err != nil { + return nil, err + } + + return &nc, nil +} + +func ip2int(ip []byte) uint32 { + if len(ip) == 16 { + return binary.BigEndian.Uint32(ip[12:16]) + } + return binary.BigEndian.Uint32(ip) +} + +func int2ip(nn uint32) net.IP { + ip := make(net.IP, net.IPv4len) + binary.BigEndian.PutUint32(ip, nn) + return ip +} + +func addr2int(addr netip.Addr) uint32 { + b := addr.Unmap().As4() + return binary.BigEndian.Uint32(b[:]) +} + +func int2addr(nn uint32) netip.Addr { + ip := [4]byte{} + binary.BigEndian.PutUint32(ip[:], nn) + return netip.AddrFrom4(ip).Unmap() +} diff --git a/cert/cert.pb.go b/cert/cert_v1.pb.go similarity index 62% rename from cert/cert.pb.go rename to cert/cert_v1.pb.go index 3570e07..32de1a0 100644 --- a/cert/cert.pb.go +++ b/cert/cert_v1.pb.go @@ -1,8 +1,8 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.30.0 +// protoc-gen-go v1.34.2 // protoc v3.21.5 -// source: cert.proto +// source: cert_v1.proto package cert @@ -50,11 +50,11 @@ func (x Curve) String() string { } func (Curve) Descriptor() protoreflect.EnumDescriptor { - return file_cert_proto_enumTypes[0].Descriptor() + return file_cert_v1_proto_enumTypes[0].Descriptor() } func (Curve) Type() protoreflect.EnumType { - return &file_cert_proto_enumTypes[0] + return &file_cert_v1_proto_enumTypes[0] } func (x Curve) Number() protoreflect.EnumNumber { @@ -63,7 +63,7 @@ func (x Curve) Number() protoreflect.EnumNumber { // Deprecated: Use Curve.Descriptor instead. func (Curve) EnumDescriptor() ([]byte, []int) { - return file_cert_proto_rawDescGZIP(), []int{0} + return file_cert_v1_proto_rawDescGZIP(), []int{0} } type RawNebulaCertificate struct { @@ -78,7 +78,7 @@ type RawNebulaCertificate struct { func (x *RawNebulaCertificate) Reset() { *x = RawNebulaCertificate{} if protoimpl.UnsafeEnabled { - mi := &file_cert_proto_msgTypes[0] + mi := &file_cert_v1_proto_msgTypes[0] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -91,7 +91,7 @@ func (x *RawNebulaCertificate) String() string { func (*RawNebulaCertificate) ProtoMessage() {} func (x *RawNebulaCertificate) ProtoReflect() protoreflect.Message { - mi := &file_cert_proto_msgTypes[0] + mi := &file_cert_v1_proto_msgTypes[0] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -104,7 +104,7 @@ func (x *RawNebulaCertificate) ProtoReflect() protoreflect.Message { // Deprecated: Use RawNebulaCertificate.ProtoReflect.Descriptor instead. func (*RawNebulaCertificate) Descriptor() ([]byte, []int) { - return file_cert_proto_rawDescGZIP(), []int{0} + return file_cert_v1_proto_rawDescGZIP(), []int{0} } func (x *RawNebulaCertificate) GetDetails() *RawNebulaCertificateDetails { @@ -143,7 +143,7 @@ type RawNebulaCertificateDetails struct { func (x *RawNebulaCertificateDetails) Reset() { *x = RawNebulaCertificateDetails{} if protoimpl.UnsafeEnabled { - mi := &file_cert_proto_msgTypes[1] + mi := &file_cert_v1_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -156,7 +156,7 @@ func (x *RawNebulaCertificateDetails) String() string { func (*RawNebulaCertificateDetails) ProtoMessage() {} func (x *RawNebulaCertificateDetails) ProtoReflect() protoreflect.Message { - mi := &file_cert_proto_msgTypes[1] + mi := &file_cert_v1_proto_msgTypes[1] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -169,7 +169,7 @@ func (x *RawNebulaCertificateDetails) ProtoReflect() protoreflect.Message { // Deprecated: Use RawNebulaCertificateDetails.ProtoReflect.Descriptor instead. func (*RawNebulaCertificateDetails) Descriptor() ([]byte, []int) { - return file_cert_proto_rawDescGZIP(), []int{1} + return file_cert_v1_proto_rawDescGZIP(), []int{1} } func (x *RawNebulaCertificateDetails) GetName() string { @@ -254,7 +254,7 @@ type RawNebulaEncryptedData struct { func (x *RawNebulaEncryptedData) Reset() { *x = RawNebulaEncryptedData{} if protoimpl.UnsafeEnabled { - mi := &file_cert_proto_msgTypes[2] + mi := &file_cert_v1_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -267,7 +267,7 @@ func (x *RawNebulaEncryptedData) String() string { func (*RawNebulaEncryptedData) ProtoMessage() {} func (x *RawNebulaEncryptedData) ProtoReflect() protoreflect.Message { - mi := &file_cert_proto_msgTypes[2] + mi := &file_cert_v1_proto_msgTypes[2] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -280,7 +280,7 @@ func (x *RawNebulaEncryptedData) ProtoReflect() protoreflect.Message { // Deprecated: Use RawNebulaEncryptedData.ProtoReflect.Descriptor instead. func (*RawNebulaEncryptedData) Descriptor() ([]byte, []int) { - return file_cert_proto_rawDescGZIP(), []int{2} + return file_cert_v1_proto_rawDescGZIP(), []int{2} } func (x *RawNebulaEncryptedData) GetEncryptionMetadata() *RawNebulaEncryptionMetadata { @@ -309,7 +309,7 @@ type RawNebulaEncryptionMetadata struct { func (x *RawNebulaEncryptionMetadata) Reset() { *x = RawNebulaEncryptionMetadata{} if protoimpl.UnsafeEnabled { - mi := &file_cert_proto_msgTypes[3] + mi := &file_cert_v1_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -322,7 +322,7 @@ func (x *RawNebulaEncryptionMetadata) String() string { func (*RawNebulaEncryptionMetadata) ProtoMessage() {} func (x *RawNebulaEncryptionMetadata) ProtoReflect() protoreflect.Message { - mi := &file_cert_proto_msgTypes[3] + mi := &file_cert_v1_proto_msgTypes[3] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -335,7 +335,7 @@ func (x *RawNebulaEncryptionMetadata) ProtoReflect() protoreflect.Message { // Deprecated: Use RawNebulaEncryptionMetadata.ProtoReflect.Descriptor instead. func (*RawNebulaEncryptionMetadata) Descriptor() ([]byte, []int) { - return file_cert_proto_rawDescGZIP(), []int{3} + return file_cert_v1_proto_rawDescGZIP(), []int{3} } func (x *RawNebulaEncryptionMetadata) GetEncryptionAlgorithm() string { @@ -367,7 +367,7 @@ type RawNebulaArgon2Parameters struct { func (x *RawNebulaArgon2Parameters) Reset() { *x = RawNebulaArgon2Parameters{} if protoimpl.UnsafeEnabled { - mi := &file_cert_proto_msgTypes[4] + mi := &file_cert_v1_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -380,7 +380,7 @@ func (x *RawNebulaArgon2Parameters) String() string { func (*RawNebulaArgon2Parameters) ProtoMessage() {} func (x *RawNebulaArgon2Parameters) ProtoReflect() protoreflect.Message { - mi := &file_cert_proto_msgTypes[4] + mi := &file_cert_v1_proto_msgTypes[4] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -393,7 +393,7 @@ func (x *RawNebulaArgon2Parameters) ProtoReflect() protoreflect.Message { // Deprecated: Use RawNebulaArgon2Parameters.ProtoReflect.Descriptor instead. func (*RawNebulaArgon2Parameters) Descriptor() ([]byte, []int) { - return file_cert_proto_rawDescGZIP(), []int{4} + return file_cert_v1_proto_rawDescGZIP(), []int{4} } func (x *RawNebulaArgon2Parameters) GetVersion() int32 { @@ -431,87 +431,87 @@ func (x *RawNebulaArgon2Parameters) GetSalt() []byte { return nil } -var File_cert_proto protoreflect.FileDescriptor +var File_cert_v1_proto protoreflect.FileDescriptor -var file_cert_proto_rawDesc = []byte{ - 0x0a, 0x0a, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x04, 0x63, 0x65, - 0x72, 0x74, 0x22, 0x71, 0x0a, 0x14, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, - 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x3b, 0x0a, 0x07, 0x44, 0x65, - 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x65, - 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, - 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x52, 0x07, - 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x53, 0x69, 0x67, 0x6e, 0x61, - 0x74, 0x75, 0x72, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x53, 0x69, 0x67, 0x6e, - 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0x9c, 0x02, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, - 0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65, - 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x49, 0x70, 0x73, - 0x18, 0x02, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x03, 0x49, 0x70, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x53, - 0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x07, 0x53, 0x75, - 0x62, 0x6e, 0x65, 0x74, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, - 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x1c, 0x0a, - 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, - 0x52, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x4e, - 0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03, 0x52, 0x08, 0x4e, - 0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x12, 0x1c, 0x0a, 0x09, 0x50, 0x75, 0x62, 0x6c, 0x69, - 0x63, 0x4b, 0x65, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x50, 0x75, 0x62, 0x6c, - 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x49, 0x73, 0x43, 0x41, 0x18, 0x08, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x04, 0x49, 0x73, 0x43, 0x41, 0x12, 0x16, 0x0a, 0x06, 0x49, 0x73, 0x73, - 0x75, 0x65, 0x72, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x49, 0x73, 0x73, 0x75, 0x65, - 0x72, 0x12, 0x21, 0x0a, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x18, 0x64, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x0b, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x43, 0x75, 0x72, 0x76, 0x65, 0x52, 0x05, 0x63, - 0x75, 0x72, 0x76, 0x65, 0x22, 0x8b, 0x01, 0x0a, 0x16, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, - 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x44, 0x61, 0x74, 0x61, 0x12, - 0x51, 0x0a, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, - 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x65, - 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, - 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, 0x12, - 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, - 0x74, 0x61, 0x12, 0x1e, 0x0a, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, - 0x78, 0x74, 0x22, 0x9c, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, - 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, - 0x74, 0x61, 0x12, 0x30, 0x0a, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, - 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, 0x6f, 0x72, - 0x69, 0x74, 0x68, 0x6d, 0x12, 0x4b, 0x0a, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, - 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, - 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41, - 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x52, - 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, - 0x73, 0x22, 0xa3, 0x01, 0x0a, 0x19, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41, - 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x12, - 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, - 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x6d, - 0x6f, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x6d, 0x65, 0x6d, 0x6f, 0x72, - 0x79, 0x12, 0x20, 0x0a, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d, - 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, - 0x69, 0x73, 0x6d, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, - 0x0c, 0x52, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x2a, 0x21, 0x0a, 0x05, 0x43, 0x75, 0x72, 0x76, 0x65, - 0x12, 0x0e, 0x0a, 0x0a, 0x43, 0x55, 0x52, 0x56, 0x45, 0x32, 0x35, 0x35, 0x31, 0x39, 0x10, 0x00, - 0x12, 0x08, 0x0a, 0x04, 0x50, 0x32, 0x35, 0x36, 0x10, 0x01, 0x42, 0x20, 0x5a, 0x1e, 0x67, 0x69, - 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x73, 0x6c, 0x61, 0x63, 0x6b, 0x68, 0x71, - 0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x62, 0x06, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x33, +var file_cert_v1_proto_rawDesc = []byte{ + 0x0a, 0x0d, 0x63, 0x65, 0x72, 0x74, 0x5f, 0x76, 0x31, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, + 0x04, 0x63, 0x65, 0x72, 0x74, 0x22, 0x71, 0x0a, 0x14, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, + 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x3b, 0x0a, + 0x07, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, + 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, + 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, + 0x73, 0x52, 0x07, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x53, 0x69, + 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x53, + 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0x9c, 0x02, 0x0a, 0x1b, 0x52, 0x61, 0x77, + 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, + 0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, + 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x03, 0x49, 0x70, 0x73, 0x12, 0x18, + 0x0a, 0x07, 0x53, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0d, 0x52, + 0x07, 0x53, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x47, 0x72, 0x6f, 0x75, + 0x70, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, + 0x12, 0x1c, 0x0a, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x18, 0x05, 0x20, + 0x01, 0x28, 0x03, 0x52, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x12, 0x1a, + 0x0a, 0x08, 0x4e, 0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03, + 0x52, 0x08, 0x4e, 0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x12, 0x1c, 0x0a, 0x09, 0x50, 0x75, + 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x50, + 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x49, 0x73, 0x43, 0x41, + 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x49, 0x73, 0x43, 0x41, 0x12, 0x16, 0x0a, 0x06, + 0x49, 0x73, 0x73, 0x75, 0x65, 0x72, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x49, 0x73, + 0x73, 0x75, 0x65, 0x72, 0x12, 0x21, 0x0a, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x18, 0x64, 0x20, + 0x01, 0x28, 0x0e, 0x32, 0x0b, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x43, 0x75, 0x72, 0x76, 0x65, + 0x52, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x22, 0x8b, 0x01, 0x0a, 0x16, 0x52, 0x61, 0x77, 0x4e, + 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x44, 0x61, + 0x74, 0x61, 0x12, 0x51, 0x0a, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, + 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, + 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, + 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, + 0x61, 0x52, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, + 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x1e, 0x0a, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, + 0x65, 0x78, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, + 0x72, 0x74, 0x65, 0x78, 0x74, 0x22, 0x9c, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, + 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, + 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x30, 0x0a, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c, + 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x12, 0x4b, 0x0a, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, + 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x1f, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, + 0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, + 0x72, 0x73, 0x52, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, + 0x74, 0x65, 0x72, 0x73, 0x22, 0xa3, 0x01, 0x0a, 0x19, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, + 0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, + 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06, + 0x6d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x6d, 0x65, + 0x6d, 0x6f, 0x72, 0x79, 0x12, 0x20, 0x0a, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, + 0x69, 0x73, 0x6d, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, + 0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x69, 0x74, 0x65, 0x72, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x2a, 0x21, 0x0a, 0x05, 0x43, 0x75, + 0x72, 0x76, 0x65, 0x12, 0x0e, 0x0a, 0x0a, 0x43, 0x55, 0x52, 0x56, 0x45, 0x32, 0x35, 0x35, 0x31, + 0x39, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x50, 0x32, 0x35, 0x36, 0x10, 0x01, 0x42, 0x20, 0x5a, + 0x1e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x73, 0x6c, 0x61, 0x63, + 0x6b, 0x68, 0x71, 0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x62, + 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( - file_cert_proto_rawDescOnce sync.Once - file_cert_proto_rawDescData = file_cert_proto_rawDesc + file_cert_v1_proto_rawDescOnce sync.Once + file_cert_v1_proto_rawDescData = file_cert_v1_proto_rawDesc ) -func file_cert_proto_rawDescGZIP() []byte { - file_cert_proto_rawDescOnce.Do(func() { - file_cert_proto_rawDescData = protoimpl.X.CompressGZIP(file_cert_proto_rawDescData) +func file_cert_v1_proto_rawDescGZIP() []byte { + file_cert_v1_proto_rawDescOnce.Do(func() { + file_cert_v1_proto_rawDescData = protoimpl.X.CompressGZIP(file_cert_v1_proto_rawDescData) }) - return file_cert_proto_rawDescData + return file_cert_v1_proto_rawDescData } -var file_cert_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_cert_proto_msgTypes = make([]protoimpl.MessageInfo, 5) -var file_cert_proto_goTypes = []interface{}{ +var file_cert_v1_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_cert_v1_proto_msgTypes = make([]protoimpl.MessageInfo, 5) +var file_cert_v1_proto_goTypes = []any{ (Curve)(0), // 0: cert.Curve (*RawNebulaCertificate)(nil), // 1: cert.RawNebulaCertificate (*RawNebulaCertificateDetails)(nil), // 2: cert.RawNebulaCertificateDetails @@ -519,7 +519,7 @@ var file_cert_proto_goTypes = []interface{}{ (*RawNebulaEncryptionMetadata)(nil), // 4: cert.RawNebulaEncryptionMetadata (*RawNebulaArgon2Parameters)(nil), // 5: cert.RawNebulaArgon2Parameters } -var file_cert_proto_depIdxs = []int32{ +var file_cert_v1_proto_depIdxs = []int32{ 2, // 0: cert.RawNebulaCertificate.Details:type_name -> cert.RawNebulaCertificateDetails 0, // 1: cert.RawNebulaCertificateDetails.curve:type_name -> cert.Curve 4, // 2: cert.RawNebulaEncryptedData.EncryptionMetadata:type_name -> cert.RawNebulaEncryptionMetadata @@ -531,13 +531,13 @@ var file_cert_proto_depIdxs = []int32{ 0, // [0:4] is the sub-list for field type_name } -func init() { file_cert_proto_init() } -func file_cert_proto_init() { - if File_cert_proto != nil { +func init() { file_cert_v1_proto_init() } +func file_cert_v1_proto_init() { + if File_cert_v1_proto != nil { return } if !protoimpl.UnsafeEnabled { - file_cert_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + file_cert_v1_proto_msgTypes[0].Exporter = func(v any, i int) any { switch v := v.(*RawNebulaCertificate); i { case 0: return &v.state @@ -549,7 +549,7 @@ func file_cert_proto_init() { return nil } } - file_cert_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + file_cert_v1_proto_msgTypes[1].Exporter = func(v any, i int) any { switch v := v.(*RawNebulaCertificateDetails); i { case 0: return &v.state @@ -561,7 +561,7 @@ func file_cert_proto_init() { return nil } } - file_cert_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + file_cert_v1_proto_msgTypes[2].Exporter = func(v any, i int) any { switch v := v.(*RawNebulaEncryptedData); i { case 0: return &v.state @@ -573,7 +573,7 @@ func file_cert_proto_init() { return nil } } - file_cert_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + file_cert_v1_proto_msgTypes[3].Exporter = func(v any, i int) any { switch v := v.(*RawNebulaEncryptionMetadata); i { case 0: return &v.state @@ -585,7 +585,7 @@ func file_cert_proto_init() { return nil } } - file_cert_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + file_cert_v1_proto_msgTypes[4].Exporter = func(v any, i int) any { switch v := v.(*RawNebulaArgon2Parameters); i { case 0: return &v.state @@ -602,19 +602,19 @@ func file_cert_proto_init() { out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_cert_proto_rawDesc, + RawDescriptor: file_cert_v1_proto_rawDesc, NumEnums: 1, NumMessages: 5, NumExtensions: 0, NumServices: 0, }, - GoTypes: file_cert_proto_goTypes, - DependencyIndexes: file_cert_proto_depIdxs, - EnumInfos: file_cert_proto_enumTypes, - MessageInfos: file_cert_proto_msgTypes, + GoTypes: file_cert_v1_proto_goTypes, + DependencyIndexes: file_cert_v1_proto_depIdxs, + EnumInfos: file_cert_v1_proto_enumTypes, + MessageInfos: file_cert_v1_proto_msgTypes, }.Build() - File_cert_proto = out.File - file_cert_proto_rawDesc = nil - file_cert_proto_goTypes = nil - file_cert_proto_depIdxs = nil + File_cert_v1_proto = out.File + file_cert_v1_proto_rawDesc = nil + file_cert_v1_proto_goTypes = nil + file_cert_v1_proto_depIdxs = nil } diff --git a/cert/cert.proto b/cert/cert_v1.proto similarity index 100% rename from cert/cert.proto rename to cert/cert_v1.proto diff --git a/cert/cert_v1_test.go b/cert/cert_v1_test.go new file mode 100644 index 0000000..8c3fe93 --- /dev/null +++ b/cert/cert_v1_test.go @@ -0,0 +1,218 @@ +package cert + +import ( + "fmt" + "net/netip" + "testing" + "time" + + "github.com/slackhq/nebula/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" +) + +func TestCertificateV1_Marshal(t *testing.T) { + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + pubKey := []byte("1234567890abcedfghij1234567890ab") + + nc := certificateV1{ + details: detailsV1{ + name: "testing", + networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + }, + unsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), + }, + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: before, + notAfter: after, + publicKey: pubKey, + isCA: false, + issuer: "1234567890abcedfghij1234567890ab", + }, + signature: []byte("1234567890abcedfghij1234567890ab"), + } + + b, err := nc.Marshal() + assert.Nil(t, err) + //t.Log("Cert size:", len(b)) + + nc2, err := unmarshalCertificateV1(b, nil) + assert.Nil(t, err) + + assert.Equal(t, nc.Version(), Version1) + assert.Equal(t, nc.Curve(), Curve_CURVE25519) + assert.Equal(t, nc.Signature(), nc2.Signature()) + assert.Equal(t, nc.Name(), nc2.Name()) + assert.Equal(t, nc.NotBefore(), nc2.NotBefore()) + assert.Equal(t, nc.NotAfter(), nc2.NotAfter()) + assert.Equal(t, nc.PublicKey(), nc2.PublicKey()) + assert.Equal(t, nc.IsCA(), nc2.IsCA()) + + assert.Equal(t, nc.Networks(), nc2.Networks()) + assert.Equal(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks()) + + assert.Equal(t, nc.Groups(), nc2.Groups()) +} + +func TestCertificateV1_Expired(t *testing.T) { + nc := certificateV1{ + details: detailsV1{ + notBefore: time.Now().Add(time.Second * -60).Round(time.Second), + notAfter: time.Now().Add(time.Second * 60).Round(time.Second), + }, + } + + assert.True(t, nc.Expired(time.Now().Add(time.Hour))) + assert.True(t, nc.Expired(time.Now().Add(-time.Hour))) + assert.False(t, nc.Expired(time.Now())) +} + +func TestCertificateV1_MarshalJSON(t *testing.T) { + time.Local = time.UTC + pubKey := []byte("1234567890abcedfghij1234567890ab") + + nc := certificateV1{ + details: detailsV1{ + name: "testing", + networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + }, + unsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), + }, + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC), + notAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC), + publicKey: pubKey, + isCA: false, + issuer: "1234567890abcedfghij1234567890ab", + }, + signature: []byte("1234567890abcedfghij1234567890ab"), + } + + b, err := nc.MarshalJSON() + assert.Nil(t, err) + assert.Equal( + t, + "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}", + string(b), + ) +} + +func TestCertificateV1_VerifyPrivateKey(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) + err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey) + assert.Nil(t, err) + + _, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) + assert.Nil(t, err) + err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) + assert.NotNil(t, err) + + c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) + assert.NoError(t, err) + assert.Empty(t, b) + assert.Equal(t, Curve_CURVE25519, curve) + err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv) + assert.Nil(t, err) + + _, priv2 := X25519Keypair() + err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) + assert.NotNil(t, err) +} + +func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) + err := ca.VerifyPrivateKey(Curve_P256, caKey) + assert.Nil(t, err) + + _, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) + assert.Nil(t, err) + err = ca.VerifyPrivateKey(Curve_P256, caKey2) + assert.NotNil(t, err) + + c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) + assert.NoError(t, err) + assert.Empty(t, b) + assert.Equal(t, Curve_P256, curve) + err = c.VerifyPrivateKey(Curve_P256, rawPriv) + assert.Nil(t, err) + + _, priv2 := P256Keypair() + err = c.VerifyPrivateKey(Curve_P256, priv2) + assert.NotNil(t, err) +} + +// Ensure that upgrading the protobuf library does not change how certificates +// are marshalled, since this would break signature verification +func TestMarshalingCertificateV1Consistency(t *testing.T) { + before := time.Date(1970, time.January, 1, 1, 1, 1, 1, time.UTC) + after := time.Date(9999, time.January, 1, 1, 1, 1, 1, time.UTC) + pubKey := []byte("1234567890abcedfghij1234567890ab") + + nc := certificateV1{ + details: detailsV1{ + name: "testing", + networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.2/16"), + mustParsePrefixUnmapped("10.1.1.1/24"), + }, + unsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.3/16"), + mustParsePrefixUnmapped("9.1.1.2/24"), + }, + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: before, + notAfter: after, + publicKey: pubKey, + isCA: false, + issuer: "1234567890abcedfghij1234567890ab", + }, + signature: []byte("1234567890abcedfghij1234567890ab"), + } + + b, err := nc.Marshal() + require.Nil(t, err) + assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b)) + + b, err = proto.Marshal(nc.getRawDetails()) + assert.Nil(t, err) + assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b)) +} + +func TestCertificateV1_Copy(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) + c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + cc := c.Copy() + test.AssertDeepCopyEqual(t, c, cc) +} + +func TestUnmarshalCertificateV1(t *testing.T) { + // Test that we don't panic with an invalid certificate (#332) + data := []byte("\x98\x00\x00") + _, err := unmarshalCertificateV1(data, nil) + assert.EqualError(t, err, "encoded Details was nil") +} + +func appendByteSlices(b ...[]byte) []byte { + retSlice := []byte{} + for _, v := range b { + retSlice = append(retSlice, v...) + } + return retSlice +} + +func mustParsePrefixUnmapped(s string) netip.Prefix { + prefix := netip.MustParsePrefix(s) + return netip.PrefixFrom(prefix.Addr().Unmap(), prefix.Bits()) +} diff --git a/cert/cert_v2.asn1 b/cert/cert_v2.asn1 new file mode 100644 index 0000000..f863133 --- /dev/null +++ b/cert/cert_v2.asn1 @@ -0,0 +1,37 @@ +Nebula DEFINITIONS AUTOMATIC TAGS ::= BEGIN + +Name ::= UTF8String (SIZE (1..253)) +Time ::= INTEGER (0..18446744073709551615) -- Seconds since unix epoch, uint64 maximum +Network ::= OCTET STRING (SIZE (5,17)) -- IP addresses are 4 or 16 bytes + 1 byte for the prefix length +Curve ::= ENUMERATED { + curve25519 (0), + p256 (1) +} + +-- The maximum size of a certificate must not exceed 65536 bytes +Certificate ::= SEQUENCE { + details OCTET STRING, + curve Curve DEFAULT curve25519, + publicKey OCTET STRING, + -- signature(details + curve + publicKey) using the appropriate method for curve + signature OCTET STRING +} + +Details ::= SEQUENCE { + name Name, + + -- At least 1 ipv4 or ipv6 address must be present if isCA is false + networks SEQUENCE OF Network OPTIONAL, + unsafeNetworks SEQUENCE OF Network OPTIONAL, + groups SEQUENCE OF Name OPTIONAL, + isCA BOOLEAN DEFAULT false, + notBefore Time, + notAfter Time, + + -- issuer is only required if isCA is false, if isCA is true then it must not be present + issuer OCTET STRING OPTIONAL, + ... + -- New fields can be added below here +} + +END \ No newline at end of file diff --git a/cert/cert_v2.go b/cert/cert_v2.go new file mode 100644 index 0000000..322463e --- /dev/null +++ b/cert/cert_v2.go @@ -0,0 +1,730 @@ +package cert + +import ( + "bytes" + "crypto/ecdh" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "encoding/pem" + "fmt" + "net/netip" + "slices" + "time" + + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/cryptobyte/asn1" + "golang.org/x/crypto/curve25519" +) + +const ( + classConstructed = 0x20 + classContextSpecific = 0x80 + + TagCertDetails = 0 | classConstructed | classContextSpecific + TagCertCurve = 1 | classContextSpecific + TagCertPublicKey = 2 | classContextSpecific + TagCertSignature = 3 | classContextSpecific + + TagDetailsName = 0 | classContextSpecific + TagDetailsNetworks = 1 | classConstructed | classContextSpecific + TagDetailsUnsafeNetworks = 2 | classConstructed | classContextSpecific + TagDetailsGroups = 3 | classConstructed | classContextSpecific + TagDetailsIsCA = 4 | classContextSpecific + TagDetailsNotBefore = 5 | classContextSpecific + TagDetailsNotAfter = 6 | classContextSpecific + TagDetailsIssuer = 7 | classContextSpecific +) + +const ( + // MaxCertificateSize is the maximum length a valid certificate can be + MaxCertificateSize = 65536 + + // MaxNameLength is limited to a maximum realistic DNS domain name to help facilitate DNS systems + MaxNameLength = 253 + + // MaxNetworkLength is the maximum length a network value can be. + // 16 bytes for an ipv6 address + 1 byte for the prefix length + MaxNetworkLength = 17 +) + +type certificateV2 struct { + details detailsV2 + + // RawDetails contains the entire asn.1 DER encoded Details struct + // This is to benefit forwards compatibility in signature checking. + // signature(RawDetails + Curve + PublicKey) == Signature + rawDetails []byte + curve Curve + publicKey []byte + signature []byte +} + +type detailsV2 struct { + name string + networks []netip.Prefix // MUST BE SORTED + unsafeNetworks []netip.Prefix // MUST BE SORTED + groups []string + isCA bool + notBefore time.Time + notAfter time.Time + issuer string +} + +func (c *certificateV2) Version() Version { + return Version2 +} + +func (c *certificateV2) Curve() Curve { + return c.curve +} + +func (c *certificateV2) Groups() []string { + return c.details.groups +} + +func (c *certificateV2) IsCA() bool { + return c.details.isCA +} + +func (c *certificateV2) Issuer() string { + return c.details.issuer +} + +func (c *certificateV2) Name() string { + return c.details.name +} + +func (c *certificateV2) Networks() []netip.Prefix { + return c.details.networks +} + +func (c *certificateV2) NotAfter() time.Time { + return c.details.notAfter +} + +func (c *certificateV2) NotBefore() time.Time { + return c.details.notBefore +} + +func (c *certificateV2) PublicKey() []byte { + return c.publicKey +} + +func (c *certificateV2) Signature() []byte { + return c.signature +} + +func (c *certificateV2) UnsafeNetworks() []netip.Prefix { + return c.details.unsafeNetworks +} + +func (c *certificateV2) Fingerprint() (string, error) { + if len(c.rawDetails) == 0 { + return "", ErrMissingDetails + } + + b := make([]byte, len(c.rawDetails)+1+len(c.publicKey)+len(c.signature)) + copy(b, c.rawDetails) + b[len(c.rawDetails)] = byte(c.curve) + copy(b[len(c.rawDetails)+1:], c.publicKey) + copy(b[len(c.rawDetails)+1+len(c.publicKey):], c.signature) + sum := sha256.Sum256(b) + return hex.EncodeToString(sum[:]), nil +} + +func (c *certificateV2) CheckSignature(key []byte) bool { + if len(c.rawDetails) == 0 { + return false + } + b := make([]byte, len(c.rawDetails)+1+len(c.publicKey)) + copy(b, c.rawDetails) + b[len(c.rawDetails)] = byte(c.curve) + copy(b[len(c.rawDetails)+1:], c.publicKey) + + switch c.curve { + case Curve_CURVE25519: + return ed25519.Verify(key, b, c.signature) + case Curve_P256: + x, y := elliptic.Unmarshal(elliptic.P256(), key) + pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} + hashed := sha256.Sum256(b) + return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature) + default: + return false + } +} + +func (c *certificateV2) Expired(t time.Time) bool { + return c.details.notBefore.After(t) || c.details.notAfter.Before(t) +} + +func (c *certificateV2) VerifyPrivateKey(curve Curve, key []byte) error { + if curve != c.curve { + return ErrPublicPrivateCurveMismatch + } + if c.details.isCA { + switch curve { + case Curve_CURVE25519: + // the call to PublicKey below will panic slice bounds out of range otherwise + if len(key) != ed25519.PrivateKeySize { + return ErrInvalidPrivateKey + } + + if !ed25519.PublicKey(c.publicKey).Equal(ed25519.PrivateKey(key).Public()) { + return ErrPublicPrivateKeyMismatch + } + case Curve_P256: + privkey, err := ecdh.P256().NewPrivateKey(key) + if err != nil { + return ErrInvalidPrivateKey + } + pub := privkey.PublicKey().Bytes() + if !bytes.Equal(pub, c.publicKey) { + return ErrPublicPrivateKeyMismatch + } + default: + return fmt.Errorf("invalid curve: %s", curve) + } + return nil + } + + var pub []byte + switch curve { + case Curve_CURVE25519: + var err error + pub, err = curve25519.X25519(key, curve25519.Basepoint) + if err != nil { + return ErrInvalidPrivateKey + } + case Curve_P256: + privkey, err := ecdh.P256().NewPrivateKey(key) + if err != nil { + return ErrInvalidPrivateKey + } + pub = privkey.PublicKey().Bytes() + default: + return fmt.Errorf("invalid curve: %s", curve) + } + if !bytes.Equal(pub, c.publicKey) { + return ErrPublicPrivateKeyMismatch + } + + return nil +} + +func (c *certificateV2) String() string { + mb, err := c.marshalJSON() + if err != nil { + return fmt.Sprintf("", err) + } + + b, err := json.MarshalIndent(mb, "", "\t") + if err != nil { + return fmt.Sprintf("", err) + } + return string(b) +} + +func (c *certificateV2) MarshalForHandshakes() ([]byte, error) { + if c.rawDetails == nil { + return nil, ErrEmptyRawDetails + } + var b cryptobyte.Builder + // Outermost certificate + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + + // Add the cert details which is already marshalled + b.AddBytes(c.rawDetails) + + // Skipping the curve and public key since those come across in a different part of the handshake + + // Add the signature + b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) { + b.AddBytes(c.signature) + }) + }) + + return b.Bytes() +} + +func (c *certificateV2) Marshal() ([]byte, error) { + if c.rawDetails == nil { + return nil, ErrEmptyRawDetails + } + var b cryptobyte.Builder + // Outermost certificate + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + + // Add the cert details which is already marshalled + b.AddBytes(c.rawDetails) + + // Add the curve only if its not the default value + if c.curve != Curve_CURVE25519 { + b.AddASN1(TagCertCurve, func(b *cryptobyte.Builder) { + b.AddBytes([]byte{byte(c.curve)}) + }) + } + + // Add the public key if it is not empty + if c.publicKey != nil { + b.AddASN1(TagCertPublicKey, func(b *cryptobyte.Builder) { + b.AddBytes(c.publicKey) + }) + } + + // Add the signature + b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) { + b.AddBytes(c.signature) + }) + }) + + return b.Bytes() +} + +func (c *certificateV2) MarshalPEM() ([]byte, error) { + b, err := c.Marshal() + if err != nil { + return nil, err + } + return pem.EncodeToMemory(&pem.Block{Type: CertificateV2Banner, Bytes: b}), nil +} + +func (c *certificateV2) MarshalJSON() ([]byte, error) { + b, err := c.marshalJSON() + if err != nil { + return nil, err + } + return json.Marshal(b) +} + +func (c *certificateV2) marshalJSON() (m, error) { + fp, err := c.Fingerprint() + if err != nil { + return nil, err + } + + return m{ + "details": m{ + "name": c.details.name, + "networks": c.details.networks, + "unsafeNetworks": c.details.unsafeNetworks, + "groups": c.details.groups, + "notBefore": c.details.notBefore, + "notAfter": c.details.notAfter, + "isCa": c.details.isCA, + "issuer": c.details.issuer, + }, + "version": Version2, + "publicKey": fmt.Sprintf("%x", c.publicKey), + "curve": c.curve.String(), + "fingerprint": fp, + "signature": fmt.Sprintf("%x", c.Signature()), + }, nil +} + +func (c *certificateV2) Copy() Certificate { + nc := &certificateV2{ + details: detailsV2{ + name: c.details.name, + notBefore: c.details.notBefore, + notAfter: c.details.notAfter, + isCA: c.details.isCA, + issuer: c.details.issuer, + }, + curve: c.curve, + publicKey: make([]byte, len(c.publicKey)), + signature: make([]byte, len(c.signature)), + rawDetails: make([]byte, len(c.rawDetails)), + } + + if c.details.groups != nil { + nc.details.groups = make([]string, len(c.details.groups)) + copy(nc.details.groups, c.details.groups) + } + + if c.details.networks != nil { + nc.details.networks = make([]netip.Prefix, len(c.details.networks)) + copy(nc.details.networks, c.details.networks) + } + + if c.details.unsafeNetworks != nil { + nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks)) + copy(nc.details.unsafeNetworks, c.details.unsafeNetworks) + } + + copy(nc.rawDetails, c.rawDetails) + copy(nc.signature, c.signature) + copy(nc.publicKey, c.publicKey) + + return nc +} + +func (c *certificateV2) fromTBSCertificate(t *TBSCertificate) error { + c.details = detailsV2{ + name: t.Name, + networks: t.Networks, + unsafeNetworks: t.UnsafeNetworks, + groups: t.Groups, + isCA: t.IsCA, + notBefore: t.NotBefore, + notAfter: t.NotAfter, + issuer: t.issuer, + } + c.curve = t.Curve + c.publicKey = t.PublicKey + return c.validate() +} + +func (c *certificateV2) validate() error { + // Empty names are allowed + + if len(c.publicKey) == 0 { + return ErrInvalidPublicKey + } + + if !c.details.isCA && len(c.details.networks) == 0 { + return NewErrInvalidCertificateProperties("non-CA certificate must contain at least 1 network") + } + + hasV4Networks := false + hasV6Networks := false + for _, network := range c.details.networks { + if !network.IsValid() || !network.Addr().IsValid() { + return NewErrInvalidCertificateProperties("invalid network: %s", network) + } + + if network.Addr().IsUnspecified() { + return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network) + } + + if network.Addr().Zone() != "" { + return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network) + } + + if network.Addr().Is4In6() { + return NewErrInvalidCertificateProperties("4in6 networks are not allowed: %s", network) + } + + hasV4Networks = hasV4Networks || network.Addr().Is4() + hasV6Networks = hasV6Networks || network.Addr().Is6() + } + + slices.SortFunc(c.details.networks, comparePrefix) + err := findDuplicatePrefix(c.details.networks) + if err != nil { + return err + } + + for _, network := range c.details.unsafeNetworks { + if !network.IsValid() || !network.Addr().IsValid() { + return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network) + } + + if network.Addr().Zone() != "" { + return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network) + } + + if !c.details.isCA { + if network.Addr().Is6() { + if !hasV6Networks { + return NewErrInvalidCertificateProperties("IPv6 unsafe networks require an IPv6 address assignment: %s", network) + } + } else if network.Addr().Is4() { + if !hasV4Networks { + return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network) + } + } + } + } + + slices.SortFunc(c.details.unsafeNetworks, comparePrefix) + err = findDuplicatePrefix(c.details.unsafeNetworks) + if err != nil { + return err + } + + return nil +} + +func (c *certificateV2) marshalForSigning() ([]byte, error) { + d, err := c.details.Marshal() + if err != nil { + return nil, fmt.Errorf("marshalling certificate details failed: %w", err) + } + c.rawDetails = d + + b := make([]byte, len(c.rawDetails)+1+len(c.publicKey)) + copy(b, c.rawDetails) + b[len(c.rawDetails)] = byte(c.curve) + copy(b[len(c.rawDetails)+1:], c.publicKey) + return b, nil +} + +func (c *certificateV2) setSignature(b []byte) error { + if len(b) == 0 { + return ErrEmptySignature + } + c.signature = b + return nil +} + +func (d *detailsV2) Marshal() ([]byte, error) { + var b cryptobyte.Builder + var err error + + // Details are a structure + b.AddASN1(TagCertDetails, func(b *cryptobyte.Builder) { + + // Add the name + b.AddASN1(TagDetailsName, func(b *cryptobyte.Builder) { + b.AddBytes([]byte(d.name)) + }) + + // Add the networks if any exist + if len(d.networks) > 0 { + b.AddASN1(TagDetailsNetworks, func(b *cryptobyte.Builder) { + for _, n := range d.networks { + sb, innerErr := n.MarshalBinary() + if innerErr != nil { + // MarshalBinary never returns an error + err = fmt.Errorf("unable to marshal network: %w", innerErr) + return + } + b.AddASN1OctetString(sb) + } + }) + } + + // Add the unsafe networks if any exist + if len(d.unsafeNetworks) > 0 { + b.AddASN1(TagDetailsUnsafeNetworks, func(b *cryptobyte.Builder) { + for _, n := range d.unsafeNetworks { + sb, innerErr := n.MarshalBinary() + if innerErr != nil { + // MarshalBinary never returns an error + err = fmt.Errorf("unable to marshal unsafe network: %w", innerErr) + return + } + b.AddASN1OctetString(sb) + } + }) + } + + // Add groups if any exist + if len(d.groups) > 0 { + b.AddASN1(TagDetailsGroups, func(b *cryptobyte.Builder) { + for _, group := range d.groups { + b.AddASN1(asn1.UTF8String, func(b *cryptobyte.Builder) { + b.AddBytes([]byte(group)) + }) + } + }) + } + + // Add IsCA only if true + if d.isCA { + b.AddASN1(TagDetailsIsCA, func(b *cryptobyte.Builder) { + b.AddUint8(0xff) + }) + } + + // Add not before + b.AddASN1Int64WithTag(d.notBefore.Unix(), TagDetailsNotBefore) + + // Add not after + b.AddASN1Int64WithTag(d.notAfter.Unix(), TagDetailsNotAfter) + + // Add the issuer if present + if d.issuer != "" { + issuerBytes, innerErr := hex.DecodeString(d.issuer) + if innerErr != nil { + err = fmt.Errorf("failed to decode issuer: %w", innerErr) + return + } + b.AddASN1(TagDetailsIssuer, func(b *cryptobyte.Builder) { + b.AddBytes(issuerBytes) + }) + } + }) + + if err != nil { + return nil, err + } + + return b.Bytes() +} + +func unmarshalCertificateV2(b []byte, publicKey []byte, curve Curve) (*certificateV2, error) { + l := len(b) + if l == 0 || l > MaxCertificateSize { + return nil, ErrBadFormat + } + + input := cryptobyte.String(b) + // Open the envelope + if !input.ReadASN1(&input, asn1.SEQUENCE) || input.Empty() { + return nil, ErrBadFormat + } + + // Grab the cert details, we need to preserve the tag and length + var rawDetails cryptobyte.String + if !input.ReadASN1Element(&rawDetails, TagCertDetails) || rawDetails.Empty() { + return nil, ErrBadFormat + } + + //Maybe grab the curve + var rawCurve byte + if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(curve)) { + return nil, ErrBadFormat + } + curve = Curve(rawCurve) + + // Maybe grab the public key + var rawPublicKey cryptobyte.String + if len(publicKey) > 0 { + rawPublicKey = publicKey + } else if !input.ReadOptionalASN1(&rawPublicKey, nil, TagCertPublicKey) { + return nil, ErrBadFormat + } + + if len(rawPublicKey) == 0 { + return nil, ErrBadFormat + } + + // Grab the signature + var rawSignature cryptobyte.String + if !input.ReadASN1(&rawSignature, TagCertSignature) || rawSignature.Empty() { + return nil, ErrBadFormat + } + + // Finally unmarshal the details + details, err := unmarshalDetails(rawDetails) + if err != nil { + return nil, err + } + + c := &certificateV2{ + details: details, + rawDetails: rawDetails, + curve: curve, + publicKey: rawPublicKey, + signature: rawSignature, + } + + err = c.validate() + if err != nil { + return nil, err + } + + return c, nil +} + +func unmarshalDetails(b cryptobyte.String) (detailsV2, error) { + // Open the envelope + if !b.ReadASN1(&b, TagCertDetails) || b.Empty() { + return detailsV2{}, ErrBadFormat + } + + // Read the name + var name cryptobyte.String + if !b.ReadASN1(&name, TagDetailsName) || name.Empty() || len(name) > MaxNameLength { + return detailsV2{}, ErrBadFormat + } + + // Read the network addresses + var subString cryptobyte.String + var found bool + + if !b.ReadOptionalASN1(&subString, &found, TagDetailsNetworks) { + return detailsV2{}, ErrBadFormat + } + + var networks []netip.Prefix + var val cryptobyte.String + if found { + for !subString.Empty() { + if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength { + return detailsV2{}, ErrBadFormat + } + + var n netip.Prefix + if err := n.UnmarshalBinary(val); err != nil { + return detailsV2{}, ErrBadFormat + } + networks = append(networks, n) + } + } + + // Read out any unsafe networks + if !b.ReadOptionalASN1(&subString, &found, TagDetailsUnsafeNetworks) { + return detailsV2{}, ErrBadFormat + } + + var unsafeNetworks []netip.Prefix + if found { + for !subString.Empty() { + if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength { + return detailsV2{}, ErrBadFormat + } + + var n netip.Prefix + if err := n.UnmarshalBinary(val); err != nil { + return detailsV2{}, ErrBadFormat + } + unsafeNetworks = append(unsafeNetworks, n) + } + } + + // Read out any groups + if !b.ReadOptionalASN1(&subString, &found, TagDetailsGroups) { + return detailsV2{}, ErrBadFormat + } + + var groups []string + if found { + for !subString.Empty() { + if !subString.ReadASN1(&val, asn1.UTF8String) || val.Empty() { + return detailsV2{}, ErrBadFormat + } + groups = append(groups, string(val)) + } + } + + // Read out IsCA + var isCa bool + if !readOptionalASN1Boolean(&b, &isCa, TagDetailsIsCA, false) { + return detailsV2{}, ErrBadFormat + } + + // Read not before and not after + var notBefore int64 + if !b.ReadASN1Int64WithTag(¬Before, TagDetailsNotBefore) { + return detailsV2{}, ErrBadFormat + } + + var notAfter int64 + if !b.ReadASN1Int64WithTag(¬After, TagDetailsNotAfter) { + return detailsV2{}, ErrBadFormat + } + + // Read issuer + var issuer cryptobyte.String + if !b.ReadOptionalASN1(&issuer, nil, TagDetailsIssuer) { + return detailsV2{}, ErrBadFormat + } + + return detailsV2{ + name: string(name), + networks: networks, + unsafeNetworks: unsafeNetworks, + groups: groups, + isCA: isCa, + notBefore: time.Unix(notBefore, 0), + notAfter: time.Unix(notAfter, 0), + issuer: hex.EncodeToString(issuer), + }, nil +} diff --git a/cert/cert_v2_test.go b/cert/cert_v2_test.go new file mode 100644 index 0000000..3afbcab --- /dev/null +++ b/cert/cert_v2_test.go @@ -0,0 +1,267 @@ +package cert + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/hex" + "net/netip" + "slices" + "testing" + "time" + + "github.com/slackhq/nebula/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCertificateV2_Marshal(t *testing.T) { + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + pubKey := []byte("1234567890abcedfghij1234567890ab") + + nc := certificateV2{ + details: detailsV2{ + name: "testing", + networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.2/16"), + mustParsePrefixUnmapped("10.1.1.1/24"), + }, + unsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.3/16"), + mustParsePrefixUnmapped("9.1.1.2/24"), + }, + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: before, + notAfter: after, + isCA: false, + issuer: "1234567890abcdef1234567890abcdef", + }, + signature: []byte("1234567890abcdef1234567890abcdef"), + publicKey: pubKey, + } + + db, err := nc.details.Marshal() + require.NoError(t, err) + nc.rawDetails = db + + b, err := nc.Marshal() + require.Nil(t, err) + //t.Log("Cert size:", len(b)) + + nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519) + assert.Nil(t, err) + + assert.Equal(t, nc.Version(), Version2) + assert.Equal(t, nc.Curve(), Curve_CURVE25519) + assert.Equal(t, nc.Signature(), nc2.Signature()) + assert.Equal(t, nc.Name(), nc2.Name()) + assert.Equal(t, nc.NotBefore(), nc2.NotBefore()) + assert.Equal(t, nc.NotAfter(), nc2.NotAfter()) + assert.Equal(t, nc.PublicKey(), nc2.PublicKey()) + assert.Equal(t, nc.IsCA(), nc2.IsCA()) + assert.Equal(t, nc.Issuer(), nc2.Issuer()) + + // unmarshalling will sort networks and unsafeNetworks, we need to do the same + // but first make sure it fails + assert.NotEqual(t, nc.Networks(), nc2.Networks()) + assert.NotEqual(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks()) + + slices.SortFunc(nc.details.networks, comparePrefix) + slices.SortFunc(nc.details.unsafeNetworks, comparePrefix) + + assert.Equal(t, nc.Networks(), nc2.Networks()) + assert.Equal(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks()) + + assert.Equal(t, nc.Groups(), nc2.Groups()) +} + +func TestCertificateV2_Expired(t *testing.T) { + nc := certificateV2{ + details: detailsV2{ + notBefore: time.Now().Add(time.Second * -60).Round(time.Second), + notAfter: time.Now().Add(time.Second * 60).Round(time.Second), + }, + } + + assert.True(t, nc.Expired(time.Now().Add(time.Hour))) + assert.True(t, nc.Expired(time.Now().Add(-time.Hour))) + assert.False(t, nc.Expired(time.Now())) +} + +func TestCertificateV2_MarshalJSON(t *testing.T) { + time.Local = time.UTC + pubKey := []byte("1234567890abcedf1234567890abcedf") + + nc := certificateV2{ + details: detailsV2{ + name: "testing", + networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + }, + unsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), + }, + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC), + notAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC), + isCA: false, + issuer: "1234567890abcedf1234567890abcedf", + }, + publicKey: pubKey, + signature: []byte("1234567890abcedf1234567890abcedf1234567890abcedf1234567890abcedf"), + } + + b, err := nc.MarshalJSON() + assert.ErrorIs(t, err, ErrMissingDetails) + + rd, err := nc.details.Marshal() + assert.NoError(t, err) + + nc.rawDetails = rd + b, err = nc.MarshalJSON() + assert.Nil(t, err) + assert.Equal( + t, + "{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}", + string(b), + ) +} + +func TestCertificateV2_VerifyPrivateKey(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) + err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey) + assert.Nil(t, err) + + err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16]) + assert.ErrorIs(t, err, ErrInvalidPrivateKey) + + _, caKey2, err := ed25519.GenerateKey(rand.Reader) + require.Nil(t, err) + err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) + assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) + + c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) + assert.NoError(t, err) + assert.Empty(t, b) + assert.Equal(t, Curve_CURVE25519, curve) + err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv) + assert.Nil(t, err) + + _, priv2 := X25519Keypair() + err = c.VerifyPrivateKey(Curve_P256, priv2) + assert.ErrorIs(t, err, ErrPublicPrivateCurveMismatch) + + err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) + assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) + + err = c.VerifyPrivateKey(Curve_CURVE25519, priv2[:16]) + assert.ErrorIs(t, err, ErrInvalidPrivateKey) + + ac, ok := c.(*certificateV2) + require.True(t, ok) + ac.curve = Curve(99) + err = c.VerifyPrivateKey(Curve(99), priv2) + assert.EqualError(t, err, "invalid curve: 99") + + ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) + err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey) + assert.Nil(t, err) + + err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16]) + assert.ErrorIs(t, err, ErrInvalidPrivateKey) + + c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil) + rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv) + + err = c.VerifyPrivateKey(Curve_P256, priv[:16]) + assert.ErrorIs(t, err, ErrInvalidPrivateKey) + + err = c.VerifyPrivateKey(Curve_P256, priv) + assert.ErrorIs(t, err, ErrInvalidPrivateKey) + + aCa, ok := ca2.(*certificateV2) + require.True(t, ok) + aCa.curve = Curve(99) + err = aCa.VerifyPrivateKey(Curve(99), priv2) + assert.EqualError(t, err, "invalid curve: 99") + +} + +func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) + err := ca.VerifyPrivateKey(Curve_P256, caKey) + assert.Nil(t, err) + + _, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) + assert.Nil(t, err) + err = ca.VerifyPrivateKey(Curve_P256, caKey2) + assert.NotNil(t, err) + + c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) + assert.NoError(t, err) + assert.Empty(t, b) + assert.Equal(t, Curve_P256, curve) + err = c.VerifyPrivateKey(Curve_P256, rawPriv) + assert.Nil(t, err) + + _, priv2 := P256Keypair() + err = c.VerifyPrivateKey(Curve_P256, priv2) + assert.NotNil(t, err) +} + +func TestCertificateV2_Copy(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) + c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + cc := c.Copy() + test.AssertDeepCopyEqual(t, c, cc) +} + +func TestUnmarshalCertificateV2(t *testing.T) { + data := []byte("\x98\x00\x00") + _, err := unmarshalCertificateV2(data, nil, Curve_CURVE25519) + assert.EqualError(t, err, "bad wire format") +} + +func TestCertificateV2_marshalForSigningStability(t *testing.T) { + before := time.Date(1996, time.May, 5, 0, 0, 0, 0, time.UTC) + after := before.Add(time.Second * 60).Round(time.Second) + pubKey := []byte("1234567890abcedfghij1234567890ab") + + nc := certificateV2{ + details: detailsV2{ + name: "testing", + networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.2/16"), + mustParsePrefixUnmapped("10.1.1.1/24"), + }, + unsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.3/16"), + mustParsePrefixUnmapped("9.1.1.2/24"), + }, + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: before, + notAfter: after, + isCA: false, + issuer: "1234567890abcdef1234567890abcdef", + }, + signature: []byte("1234567890abcdef1234567890abcdef"), + publicKey: pubKey, + } + + const expectedRawDetailsStr = "a070800774657374696e67a10e04050a0101021004050a01010118a20e0405090101031004050901010218a3270c0b746573742d67726f7570310c0b746573742d67726f7570320c0b746573742d67726f7570338504318bef808604318befbc87101234567890abcdef1234567890abcdef" + expectedRawDetails, err := hex.DecodeString(expectedRawDetailsStr) + require.NoError(t, err) + + db, err := nc.details.Marshal() + require.NoError(t, err) + assert.Equal(t, expectedRawDetails, db) + + expectedForSigning, err := hex.DecodeString(expectedRawDetailsStr + "00313233343536373839306162636564666768696a313233343536373839306162") + b, err := nc.marshalForSigning() + require.NoError(t, err) + assert.Equal(t, expectedForSigning, b) +} diff --git a/cert/crypto.go b/cert/crypto.go index 3558e1a..4c236ae 100644 --- a/cert/crypto.go +++ b/cert/crypto.go @@ -3,14 +3,28 @@ package cert import ( "crypto/aes" "crypto/cipher" + "crypto/ed25519" "crypto/rand" + "encoding/pem" "fmt" "io" + "math" "golang.org/x/crypto/argon2" + "google.golang.org/protobuf/proto" ) -// KDF factors +type NebulaEncryptedData struct { + EncryptionMetadata NebulaEncryptionMetadata + Ciphertext []byte +} + +type NebulaEncryptionMetadata struct { + EncryptionAlgorithm string + Argon2Parameters Argon2Parameters +} + +// Argon2Parameters KDF factors type Argon2Parameters struct { version rune Memory uint32 // KiB @@ -19,7 +33,7 @@ type Argon2Parameters struct { salt []byte } -// Returns a new Argon2Parameters object with current version set +// NewArgon2Parameters Returns a new Argon2Parameters object with current version set func NewArgon2Parameters(memory uint32, parallelism uint8, iterations uint32) *Argon2Parameters { return &Argon2Parameters{ version: argon2.Version, @@ -141,3 +155,146 @@ func splitNonceCiphertext(blob []byte, nonceSize int) ([]byte, []byte, error) { return blob[:nonceSize], blob[nonceSize:], nil } + +// EncryptAndMarshalSigningPrivateKey is a simple helper to encrypt and PEM encode a private key +func EncryptAndMarshalSigningPrivateKey(curve Curve, b []byte, passphrase []byte, kdfParams *Argon2Parameters) ([]byte, error) { + ciphertext, err := aes256Encrypt(passphrase, kdfParams, b) + if err != nil { + return nil, err + } + + b, err = proto.Marshal(&RawNebulaEncryptedData{ + EncryptionMetadata: &RawNebulaEncryptionMetadata{ + EncryptionAlgorithm: "AES-256-GCM", + Argon2Parameters: &RawNebulaArgon2Parameters{ + Version: kdfParams.version, + Memory: kdfParams.Memory, + Parallelism: uint32(kdfParams.Parallelism), + Iterations: kdfParams.Iterations, + Salt: kdfParams.salt, + }, + }, + Ciphertext: ciphertext, + }) + if err != nil { + return nil, err + } + + switch curve { + case Curve_CURVE25519: + return pem.EncodeToMemory(&pem.Block{Type: EncryptedEd25519PrivateKeyBanner, Bytes: b}), nil + case Curve_P256: + return pem.EncodeToMemory(&pem.Block{Type: EncryptedECDSAP256PrivateKeyBanner, Bytes: b}), nil + default: + return nil, fmt.Errorf("invalid curve: %v", curve) + } +} + +// UnmarshalNebulaEncryptedData will unmarshal a protobuf byte representation of a nebula cert into its +// protobuf-generated struct. +func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) { + if len(b) == 0 { + return nil, fmt.Errorf("nil byte array") + } + var rned RawNebulaEncryptedData + err := proto.Unmarshal(b, &rned) + if err != nil { + return nil, err + } + + if rned.EncryptionMetadata == nil { + return nil, fmt.Errorf("encoded EncryptionMetadata was nil") + } + + if rned.EncryptionMetadata.Argon2Parameters == nil { + return nil, fmt.Errorf("encoded Argon2Parameters was nil") + } + + params, err := unmarshalArgon2Parameters(rned.EncryptionMetadata.Argon2Parameters) + if err != nil { + return nil, err + } + + ned := NebulaEncryptedData{ + EncryptionMetadata: NebulaEncryptionMetadata{ + EncryptionAlgorithm: rned.EncryptionMetadata.EncryptionAlgorithm, + Argon2Parameters: *params, + }, + Ciphertext: rned.Ciphertext, + } + + return &ned, nil +} + +func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) { + if params.Version < math.MinInt32 || params.Version > math.MaxInt32 { + return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32) + } + if params.Memory <= 0 || params.Memory > math.MaxUint32 { + return nil, fmt.Errorf("Argon2Parameters Memory must be be greater than 0 and no more than %d KiB", uint32(math.MaxUint32)) + } + if params.Parallelism <= 0 || params.Parallelism > math.MaxUint8 { + return nil, fmt.Errorf("Argon2Parameters Parallelism must be be greater than 0 and no more than %d", math.MaxUint8) + } + if params.Iterations <= 0 || params.Iterations > math.MaxUint32 { + return nil, fmt.Errorf("-argon-iterations must be be greater than 0 and no more than %d", uint32(math.MaxUint32)) + } + + return &Argon2Parameters{ + version: params.Version, + Memory: params.Memory, + Parallelism: uint8(params.Parallelism), + Iterations: params.Iterations, + salt: params.Salt, + }, nil + +} + +// DecryptAndUnmarshalSigningPrivateKey will try to pem decode and decrypt an Ed25519/ECDSA private key with +// the given passphrase, returning any other bytes b or an error on failure +func DecryptAndUnmarshalSigningPrivateKey(passphrase, b []byte) (Curve, []byte, []byte, error) { + var curve Curve + + k, r := pem.Decode(b) + if k == nil { + return curve, nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") + } + + switch k.Type { + case EncryptedEd25519PrivateKeyBanner: + curve = Curve_CURVE25519 + case EncryptedECDSAP256PrivateKeyBanner: + curve = Curve_P256 + default: + return curve, nil, r, fmt.Errorf("bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") + } + + ned, err := UnmarshalNebulaEncryptedData(k.Bytes) + if err != nil { + return curve, nil, r, err + } + + var bytes []byte + switch ned.EncryptionMetadata.EncryptionAlgorithm { + case "AES-256-GCM": + bytes, err = aes256Decrypt(passphrase, &ned.EncryptionMetadata.Argon2Parameters, ned.Ciphertext) + if err != nil { + return curve, nil, r, err + } + default: + return curve, nil, r, fmt.Errorf("unsupported encryption algorithm: %s", ned.EncryptionMetadata.EncryptionAlgorithm) + } + + switch curve { + case Curve_CURVE25519: + if len(bytes) != ed25519.PrivateKeySize { + return curve, nil, r, fmt.Errorf("key was not %d bytes, is invalid ed25519 private key", ed25519.PrivateKeySize) + } + case Curve_P256: + if len(bytes) != 32 { + return curve, nil, r, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key") + } + } + + return curve, bytes, r, nil +} diff --git a/cert/crypto_test.go b/cert/crypto_test.go index c2e61df..c9aba3e 100644 --- a/cert/crypto_test.go +++ b/cert/crypto_test.go @@ -23,3 +23,90 @@ func TestNewArgon2Parameters(t *testing.T) { Iterations: 1, }, p) } + +func TestDecryptAndUnmarshalSigningPrivateKey(t *testing.T) { + passphrase := []byte("DO NOT USE THIS KEY") + privKey := []byte(`# A good key +-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT +oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl ++Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB +qrlJ69wer3ZUHFXA +-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +`) + shortKey := []byte(`# A key which, once decrypted, is too short +-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCoga5h8owMEBWRSMMJKzuUvWce7 +k0qlBkQmCxiuLh80MuASW70YcKt8jeEIS2axo2V6zAKA9TSMcCsJW1kDDXEtL/xe +GLF5T7sDl5COp4LU3pGxpV+KoeQ/S3gQCAAcnaOtnJQX+aSDnbO3jCHyP7U9CHbs +rQr3bdH3Oy/WiYU= +-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +`) + invalidBanner := []byte(`# Invalid banner (not encrypted) +-----BEGIN NEBULA ED25519 PRIVATE KEY----- +bWRp2CTVFhW9HD/qCd28ltDgK3w8VXSeaEYczDWos8sMUBqDb9jP3+NYwcS4lURG +XgLvodMXZJuaFPssp+WwtA== +-----END NEBULA ED25519 PRIVATE KEY----- +`) + invalidPem := []byte(`# Not a valid PEM format +-BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT +oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl ++Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB +qrlJ69wer3ZUHFXA +-END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +`) + + keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem) + + // Success test case + curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle) + assert.Nil(t, err) + assert.Equal(t, Curve_CURVE25519, curve) + assert.Len(t, k, 64) + assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) + + // Fail due to short key + curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) + assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key") + assert.Nil(t, k) + assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) + + // Fail due to invalid banner + curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) + assert.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + + // Fail due to ivalid PEM format, because + // it's missing the requisite pre-encapsulation boundary. + curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) + assert.EqualError(t, err, "input did not contain a valid PEM encoded block") + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + + // Fail due to invalid passphrase + curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey) + assert.EqualError(t, err, "invalid passphrase or corrupt private key") + assert.Nil(t, k) + assert.Equal(t, rest, []byte{}) +} + +func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) { + // Having proved that decryption works correctly above, we can test the + // encryption function produces a value which can be decrypted + passphrase := []byte("passphrase") + bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") + kdfParams := NewArgon2Parameters(64*1024, 4, 3) + key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams) + assert.Nil(t, err) + + // Verify the "key" can be decrypted successfully + curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key) + assert.Len(t, k, 64) + assert.Equal(t, Curve_CURVE25519, curve) + assert.Equal(t, rest, []byte{}) + assert.Nil(t, err) + + // EncryptAndMarshalEd25519PrivateKey does not create any errors itself +} diff --git a/cert/errors.go b/cert/errors.go index 05b42d1..4bbc023 100644 --- a/cert/errors.go +++ b/cert/errors.go @@ -2,13 +2,48 @@ package cert import ( "errors" + "fmt" ) var ( - ErrRootExpired = errors.New("root certificate is expired") - ErrExpired = errors.New("certificate is expired") - ErrNotCA = errors.New("certificate is not a CA") - ErrNotSelfSigned = errors.New("certificate is not self-signed") - ErrBlockListed = errors.New("certificate is in the block list") - ErrSignatureMismatch = errors.New("certificate signature did not match") + ErrBadFormat = errors.New("bad wire format") + ErrRootExpired = errors.New("root certificate is expired") + ErrExpired = errors.New("certificate is expired") + ErrNotCA = errors.New("certificate is not a CA") + ErrNotSelfSigned = errors.New("certificate is not self-signed") + ErrBlockListed = errors.New("certificate is in the block list") + ErrFingerprintMismatch = errors.New("certificate fingerprint did not match") + ErrSignatureMismatch = errors.New("certificate signature did not match") + ErrInvalidPublicKey = errors.New("invalid public key") + ErrInvalidPrivateKey = errors.New("invalid private key") + ErrPublicPrivateCurveMismatch = errors.New("public key does not match private key curve") + ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair") + ErrPrivateKeyEncrypted = errors.New("private key must be decrypted") + ErrCaNotFound = errors.New("could not find ca for the certificate") + + ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block") + ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner") + ErrInvalidPEMX25519PublicKeyBanner = errors.New("bytes did not contain a proper X25519 public key banner") + ErrInvalidPEMX25519PrivateKeyBanner = errors.New("bytes did not contain a proper X25519 private key banner") + ErrInvalidPEMEd25519PublicKeyBanner = errors.New("bytes did not contain a proper Ed25519 public key banner") + ErrInvalidPEMEd25519PrivateKeyBanner = errors.New("bytes did not contain a proper Ed25519 private key banner") + + ErrNoPeerStaticKey = errors.New("no peer static key was present") + ErrNoPayload = errors.New("provided payload was empty") + + ErrMissingDetails = errors.New("certificate did not contain details") + ErrEmptySignature = errors.New("empty signature") + ErrEmptyRawDetails = errors.New("empty rawDetails not allowed") ) + +type ErrInvalidCertificateProperties struct { + str string +} + +func NewErrInvalidCertificateProperties(format string, a ...any) error { + return &ErrInvalidCertificateProperties{fmt.Sprintf(format, a...)} +} + +func (e *ErrInvalidCertificateProperties) Error() string { + return e.str +} diff --git a/cert/helper_test.go b/cert/helper_test.go new file mode 100644 index 0000000..1b72a0f --- /dev/null +++ b/cert/helper_test.go @@ -0,0 +1,141 @@ +package cert + +import ( + "crypto/ecdh" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "io" + "net/netip" + "time" + + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/ed25519" +) + +// NewTestCaCert will create a new ca certificate +func NewTestCaCert(version Version, curve Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) { + var err error + var pub, priv []byte + + switch curve { + case Curve_CURVE25519: + pub, priv, err = ed25519.GenerateKey(rand.Reader) + case Curve_P256: + privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(err) + } + + pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y) + priv = privk.D.FillBytes(make([]byte, 32)) + default: + // There is no default to allow the underlying lib to respond with an error + } + + if before.IsZero() { + before = time.Now().Add(time.Second * -60).Round(time.Second) + } + if after.IsZero() { + after = time.Now().Add(time.Second * 60).Round(time.Second) + } + + t := &TBSCertificate{ + Curve: curve, + Version: version, + Name: "test ca", + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + Networks: networks, + UnsafeNetworks: unsafeNetworks, + Groups: groups, + IsCA: true, + } + + c, err := t.Sign(nil, curve, priv) + if err != nil { + panic(err) + } + + pem, err := c.MarshalPEM() + if err != nil { + panic(err) + } + + return c, pub, priv, pem +} + +// NewTestCert will generate a signed certificate with the provided details. +// Expiry times are defaulted if you do not pass them in +func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) { + if before.IsZero() { + before = time.Now().Add(time.Second * -60).Round(time.Second) + } + + if after.IsZero() { + after = time.Now().Add(time.Second * 60).Round(time.Second) + } + + if len(networks) == 0 { + networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")} + } + + var pub, priv []byte + switch curve { + case Curve_CURVE25519: + pub, priv = X25519Keypair() + case Curve_P256: + pub, priv = P256Keypair() + default: + panic("unknown curve") + } + + nc := &TBSCertificate{ + Version: v, + Curve: curve, + Name: name, + Networks: networks, + UnsafeNetworks: unsafeNetworks, + Groups: groups, + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + IsCA: false, + } + + c, err := nc.Sign(ca, ca.Curve(), key) + if err != nil { + panic(err) + } + + pem, err := c.MarshalPEM() + if err != nil { + panic(err) + } + + return c, pub, MarshalPrivateKeyToPEM(curve, priv), pem +} + +func X25519Keypair() ([]byte, []byte) { + privkey := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, privkey); err != nil { + panic(err) + } + + pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint) + if err != nil { + panic(err) + } + + return pubkey, privkey +} + +func P256Keypair() ([]byte, []byte) { + privkey, err := ecdh.P256().GenerateKey(rand.Reader) + if err != nil { + panic(err) + } + pubkey := privkey.PublicKey() + return pubkey.Bytes(), privkey.Bytes() +} diff --git a/cert/pem.go b/cert/pem.go new file mode 100644 index 0000000..7ad28d1 --- /dev/null +++ b/cert/pem.go @@ -0,0 +1,161 @@ +package cert + +import ( + "encoding/pem" + "fmt" + + "golang.org/x/crypto/ed25519" +) + +const ( + CertificateBanner = "NEBULA CERTIFICATE" + CertificateV2Banner = "NEBULA CERTIFICATE V2" + X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY" + X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY" + EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY" + Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY" + Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY" + + P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY" + P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY" + EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY" + ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY" +) + +// UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed +// data or an error on failure +func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) { + p, r := pem.Decode(b) + if p == nil { + return nil, r, ErrInvalidPEMBlock + } + + var c Certificate + var err error + + switch p.Type { + // Implementations must validate the resulting certificate contains valid information + case CertificateBanner: + c, err = unmarshalCertificateV1(p.Bytes, nil) + case CertificateV2Banner: + c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519) + default: + return nil, r, ErrInvalidPEMCertificateBanner + } + + if err != nil { + return nil, r, err + } + + return c, r, nil + +} + +func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte { + switch curve { + case Curve_CURVE25519: + return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b}) + case Curve_P256: + return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b}) + default: + return nil + } +} + +func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) { + k, r := pem.Decode(b) + if k == nil { + return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block") + } + var expectedLen int + var curve Curve + switch k.Type { + case X25519PublicKeyBanner, Ed25519PublicKeyBanner: + expectedLen = 32 + curve = Curve_CURVE25519 + case P256PublicKeyBanner: + // Uncompressed + expectedLen = 65 + curve = Curve_P256 + default: + return nil, r, 0, fmt.Errorf("bytes did not contain a proper public key banner") + } + if len(k.Bytes) != expectedLen { + return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s public key", expectedLen, curve) + } + return k.Bytes, r, curve, nil +} + +func MarshalPrivateKeyToPEM(curve Curve, b []byte) []byte { + switch curve { + case Curve_CURVE25519: + return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b}) + case Curve_P256: + return pem.EncodeToMemory(&pem.Block{Type: P256PrivateKeyBanner, Bytes: b}) + default: + return nil + } +} + +func MarshalSigningPrivateKeyToPEM(curve Curve, b []byte) []byte { + switch curve { + case Curve_CURVE25519: + return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: b}) + case Curve_P256: + return pem.EncodeToMemory(&pem.Block{Type: ECDSAP256PrivateKeyBanner, Bytes: b}) + default: + return nil + } +} + +// UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non +// consumed data or an error on failure +func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) { + k, r := pem.Decode(b) + if k == nil { + return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block") + } + var expectedLen int + var curve Curve + switch k.Type { + case X25519PrivateKeyBanner: + expectedLen = 32 + curve = Curve_CURVE25519 + case P256PrivateKeyBanner: + expectedLen = 32 + curve = Curve_P256 + default: + return nil, r, 0, fmt.Errorf("bytes did not contain a proper private key banner") + } + if len(k.Bytes) != expectedLen { + return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s private key", expectedLen, curve) + } + return k.Bytes, r, curve, nil +} + +func UnmarshalSigningPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) { + k, r := pem.Decode(b) + if k == nil { + return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block") + } + var curve Curve + switch k.Type { + case EncryptedEd25519PrivateKeyBanner: + return nil, nil, Curve_CURVE25519, ErrPrivateKeyEncrypted + case EncryptedECDSAP256PrivateKeyBanner: + return nil, nil, Curve_P256, ErrPrivateKeyEncrypted + case Ed25519PrivateKeyBanner: + curve = Curve_CURVE25519 + if len(k.Bytes) != ed25519.PrivateKeySize { + return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid Ed25519 private key", ed25519.PrivateKeySize) + } + case ECDSAP256PrivateKeyBanner: + curve = Curve_P256 + if len(k.Bytes) != 32 { + return nil, r, 0, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key") + } + default: + return nil, r, 0, fmt.Errorf("bytes did not contain a proper Ed25519/ECDSA private key banner") + } + return k.Bytes, r, curve, nil +} diff --git a/cert/pem_test.go b/cert/pem_test.go new file mode 100644 index 0000000..a0c6e74 --- /dev/null +++ b/cert/pem_test.go @@ -0,0 +1,292 @@ +package cert + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestUnmarshalCertificateFromPEM(t *testing.T) { + goodCert := []byte(` +# A good cert +-----BEGIN NEBULA CERTIFICATE----- +CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL +vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv +bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB +-----END NEBULA CERTIFICATE----- +`) + badBanner := []byte(`# A bad banner +-----BEGIN NOT A NEBULA CERTIFICATE----- +CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL +vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv +bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB +-----END NOT A NEBULA CERTIFICATE----- +`) + invalidPem := []byte(`# Not a valid PEM format +-BEGIN NEBULA CERTIFICATE----- +CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL +vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv +bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB +-END NEBULA CERTIFICATE----`) + + certBundle := appendByteSlices(goodCert, badBanner, invalidPem) + + // Success test case + cert, rest, err := UnmarshalCertificateFromPEM(certBundle) + assert.NotNil(t, cert) + assert.Equal(t, rest, append(badBanner, invalidPem...)) + assert.Nil(t, err) + + // Fail due to invalid banner. + cert, rest, err = UnmarshalCertificateFromPEM(rest) + assert.Nil(t, cert) + assert.Equal(t, rest, invalidPem) + assert.EqualError(t, err, "bytes did not contain a proper certificate banner") + + // Fail due to ivalid PEM format, because + // it's missing the requisite pre-encapsulation boundary. + cert, rest, err = UnmarshalCertificateFromPEM(rest) + assert.Nil(t, cert) + assert.Equal(t, rest, invalidPem) + assert.EqualError(t, err, "input did not contain a valid PEM encoded block") +} + +func TestUnmarshalSigningPrivateKeyFromPEM(t *testing.T) { + privKey := []byte(`# A good key +-----BEGIN NEBULA ED25519 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== +-----END NEBULA ED25519 PRIVATE KEY----- +`) + privP256Key := []byte(`# A good key +-----BEGIN NEBULA ECDSA P256 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NEBULA ECDSA P256 PRIVATE KEY----- +`) + shortKey := []byte(`# A short key +-----BEGIN NEBULA ED25519 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA +-----END NEBULA ED25519 PRIVATE KEY----- +`) + invalidBanner := []byte(`# Invalid banner +-----BEGIN NOT A NEBULA PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== +-----END NOT A NEBULA PRIVATE KEY----- +`) + invalidPem := []byte(`# Not a valid PEM format +-BEGIN NEBULA ED25519 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== +-END NEBULA ED25519 PRIVATE KEY-----`) + + keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem) + + // Success test case + k, rest, curve, err := UnmarshalSigningPrivateKeyFromPEM(keyBundle) + assert.Len(t, k, 64) + assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) + assert.Equal(t, Curve_CURVE25519, curve) + assert.Nil(t, err) + + // Success test case + k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) + assert.Len(t, k, 32) + assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) + assert.Equal(t, Curve_P256, curve) + assert.Nil(t, err) + + // Fail due to short key + k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) + assert.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key") + + // Fail due to invalid banner + k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + assert.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner") + + // Fail due to ivalid PEM format, because + // it's missing the requisite pre-encapsulation boundary. + k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + assert.EqualError(t, err, "input did not contain a valid PEM encoded block") +} + +func TestUnmarshalPrivateKeyFromPEM(t *testing.T) { + privKey := []byte(`# A good key +-----BEGIN NEBULA X25519 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NEBULA X25519 PRIVATE KEY----- +`) + privP256Key := []byte(`# A good key +-----BEGIN NEBULA P256 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NEBULA P256 PRIVATE KEY----- +`) + shortKey := []byte(`# A short key +-----BEGIN NEBULA X25519 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== +-----END NEBULA X25519 PRIVATE KEY----- +`) + invalidBanner := []byte(`# Invalid banner +-----BEGIN NOT A NEBULA PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NOT A NEBULA PRIVATE KEY----- +`) + invalidPem := []byte(`# Not a valid PEM format +-BEGIN NEBULA X25519 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-END NEBULA X25519 PRIVATE KEY-----`) + + keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem) + + // Success test case + k, rest, curve, err := UnmarshalPrivateKeyFromPEM(keyBundle) + assert.Len(t, k, 32) + assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) + assert.Equal(t, Curve_CURVE25519, curve) + assert.Nil(t, err) + + // Success test case + k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) + assert.Len(t, k, 32) + assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) + assert.Equal(t, Curve_P256, curve) + assert.Nil(t, err) + + // Fail due to short key + k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) + assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key") + + // Fail due to invalid banner + k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + assert.EqualError(t, err, "bytes did not contain a proper private key banner") + + // Fail due to ivalid PEM format, because + // it's missing the requisite pre-encapsulation boundary. + k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + assert.EqualError(t, err, "input did not contain a valid PEM encoded block") +} + +func TestUnmarshalPublicKeyFromPEM(t *testing.T) { + pubKey := []byte(`# A good key +-----BEGIN NEBULA ED25519 PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NEBULA ED25519 PUBLIC KEY----- +`) + shortKey := []byte(`# A short key +-----BEGIN NEBULA ED25519 PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== +-----END NEBULA ED25519 PUBLIC KEY----- +`) + invalidBanner := []byte(`# Invalid banner +-----BEGIN NOT A NEBULA PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NOT A NEBULA PUBLIC KEY----- +`) + invalidPem := []byte(`# Not a valid PEM format +-BEGIN NEBULA ED25519 PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-END NEBULA ED25519 PUBLIC KEY-----`) + + keyBundle := appendByteSlices(pubKey, shortKey, invalidBanner, invalidPem) + + // Success test case + k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle) + assert.Equal(t, 32, len(k)) + assert.Equal(t, Curve_CURVE25519, curve) + assert.Nil(t, err) + assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) + + // Fail due to short key + k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, Curve_CURVE25519, curve) + assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) + assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") + + // Fail due to invalid banner + k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, Curve_CURVE25519, curve) + assert.EqualError(t, err, "bytes did not contain a proper public key banner") + assert.Equal(t, rest, invalidPem) + + // Fail due to ivalid PEM format, because + // it's missing the requisite pre-encapsulation boundary. + k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, Curve_CURVE25519, curve) + assert.Equal(t, rest, invalidPem) + assert.EqualError(t, err, "input did not contain a valid PEM encoded block") +} + +func TestUnmarshalX25519PublicKey(t *testing.T) { + pubKey := []byte(`# A good key +-----BEGIN NEBULA X25519 PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NEBULA X25519 PUBLIC KEY----- +`) + pubP256Key := []byte(`# A good key +-----BEGIN NEBULA P256 PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA +AAAAAAAAAAAAAAAAAAAAAAA= +-----END NEBULA P256 PUBLIC KEY----- +`) + shortKey := []byte(`# A short key +-----BEGIN NEBULA X25519 PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== +-----END NEBULA X25519 PUBLIC KEY----- +`) + invalidBanner := []byte(`# Invalid banner +-----BEGIN NOT A NEBULA PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NOT A NEBULA PUBLIC KEY----- +`) + invalidPem := []byte(`# Not a valid PEM format +-BEGIN NEBULA X25519 PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-END NEBULA X25519 PUBLIC KEY-----`) + + keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem) + + // Success test case + k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle) + assert.Equal(t, 32, len(k)) + assert.Nil(t, err) + assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem)) + assert.Equal(t, Curve_CURVE25519, curve) + + // Success test case + k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) + assert.Equal(t, 65, len(k)) + assert.Nil(t, err) + assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) + assert.Equal(t, Curve_P256, curve) + + // Fail due to short key + k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) + assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") + + // Fail due to invalid banner + k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) + assert.Nil(t, k) + assert.EqualError(t, err, "bytes did not contain a proper public key banner") + assert.Equal(t, rest, invalidPem) + + // Fail due to ivalid PEM format, because + // it's missing the requisite pre-encapsulation boundary. + k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + assert.EqualError(t, err, "input did not contain a valid PEM encoded block") +} diff --git a/cert/sign.go b/cert/sign.go new file mode 100644 index 0000000..12d4ee4 --- /dev/null +++ b/cert/sign.go @@ -0,0 +1,167 @@ +package cert + +import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "fmt" + "math/big" + "net/netip" + "time" +) + +// TBSCertificate represents a certificate intended to be signed. +// It is invalid to use this structure as a Certificate. +type TBSCertificate struct { + Version Version + Name string + Networks []netip.Prefix + UnsafeNetworks []netip.Prefix + Groups []string + IsCA bool + NotBefore time.Time + NotAfter time.Time + PublicKey []byte + Curve Curve + issuer string +} + +type beingSignedCertificate interface { + // fromTBSCertificate copies the values from the TBSCertificate to this versions internal representation + // Implementations must validate the resulting certificate contains valid information + fromTBSCertificate(*TBSCertificate) error + + // marshalForSigning returns the bytes that should be signed + marshalForSigning() ([]byte, error) + + // setSignature sets the signature for the certificate that has just been signed. The signature must not be blank. + setSignature([]byte) error +} + +type SignerLambda func(certBytes []byte) ([]byte, error) + +// Sign will create a sealed certificate using details provided by the TBSCertificate as long as those +// details do not violate constraints of the signing certificate. +// If the TBSCertificate is a CA then signer must be nil. +func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Certificate, error) { + switch t.Curve { + case Curve_CURVE25519: + pk := ed25519.PrivateKey(key) + sp := func(certBytes []byte) ([]byte, error) { + sig := ed25519.Sign(pk, certBytes) + return sig, nil + } + return t.SignWith(signer, curve, sp) + case Curve_P256: + pk := &ecdsa.PrivateKey{ + PublicKey: ecdsa.PublicKey{ + Curve: elliptic.P256(), + }, + // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95 + D: new(big.Int).SetBytes(key), + } + // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119 + pk.X, pk.Y = pk.Curve.ScalarBaseMult(key) + sp := func(certBytes []byte) ([]byte, error) { + // We need to hash first for ECDSA + // - https://pkg.go.dev/crypto/ecdsa#SignASN1 + hashed := sha256.Sum256(certBytes) + return ecdsa.SignASN1(rand.Reader, pk, hashed[:]) + } + return t.SignWith(signer, curve, sp) + default: + return nil, fmt.Errorf("invalid curve: %s", t.Curve) + } +} + +// SignWith does the same thing as sign, but uses the function in `sp` to calculate the signature. +// You should only use SignWith if you do not have direct access to your private key. +func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLambda) (Certificate, error) { + if curve != t.Curve { + return nil, fmt.Errorf("curve in cert and private key supplied don't match") + } + + if signer != nil { + if t.IsCA { + return nil, fmt.Errorf("can not sign a CA certificate with another") + } + + err := checkCAConstraints(signer, t.NotBefore, t.NotAfter, t.Groups, t.Networks, t.UnsafeNetworks) + if err != nil { + return nil, err + } + + issuer, err := signer.Fingerprint() + if err != nil { + return nil, fmt.Errorf("error computing issuer: %v", err) + } + t.issuer = issuer + } else { + if !t.IsCA { + return nil, fmt.Errorf("self signed certificates must have IsCA set to true") + } + } + + var c beingSignedCertificate + switch t.Version { + case Version1: + c = &certificateV1{} + err := c.fromTBSCertificate(t) + if err != nil { + return nil, err + } + case Version2: + c = &certificateV2{} + err := c.fromTBSCertificate(t) + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("unknown cert version %d", t.Version) + } + + certBytes, err := c.marshalForSigning() + if err != nil { + return nil, err + } + + sig, err := sp(certBytes) + if err != nil { + return nil, err + } + + err = c.setSignature(sig) + if err != nil { + return nil, err + } + + sc, ok := c.(Certificate) + if !ok { + return nil, fmt.Errorf("invalid certificate") + } + + return sc, nil +} + +func comparePrefix(a, b netip.Prefix) int { + addr := a.Addr().Compare(b.Addr()) + if addr == 0 { + return a.Bits() - b.Bits() + } + return addr +} + +// findDuplicatePrefix returns an error if there is a duplicate prefix in the pre-sorted input slice sortedPrefixes +func findDuplicatePrefix(sortedPrefixes []netip.Prefix) error { + if len(sortedPrefixes) < 2 { + return nil + } + for i := 1; i < len(sortedPrefixes); i++ { + if comparePrefix(sortedPrefixes[i], sortedPrefixes[i-1]) == 0 { + return NewErrInvalidCertificateProperties("duplicate network detected: %v", sortedPrefixes[i]) + } + } + return nil +} diff --git a/cert/sign_test.go b/cert/sign_test.go new file mode 100644 index 0000000..2b8dbe8 --- /dev/null +++ b/cert/sign_test.go @@ -0,0 +1,90 @@ +package cert + +import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestCertificateV1_Sign(t *testing.T) { + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + pubKey := []byte("1234567890abcedfghij1234567890ab") + + tbs := TBSCertificate{ + Version: Version1, + Name: "testing", + Networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + }, + UnsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/24"), + }, + Groups: []string{"test-group1", "test-group2", "test-group3"}, + NotBefore: before, + NotAfter: after, + PublicKey: pubKey, + IsCA: false, + } + + pub, priv, err := ed25519.GenerateKey(rand.Reader) + c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv) + assert.Nil(t, err) + assert.NotNil(t, c) + assert.True(t, c.CheckSignature(pub)) + + b, err := c.Marshal() + assert.Nil(t, err) + uc, err := unmarshalCertificateV1(b, nil) + assert.Nil(t, err) + assert.NotNil(t, uc) +} + +func TestCertificateV1_SignP256(t *testing.T) { + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab") + + tbs := TBSCertificate{ + Version: Version1, + Name: "testing", + Networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + }, + UnsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), + }, + Groups: []string{"test-group1", "test-group2", "test-group3"}, + NotBefore: before, + NotAfter: after, + PublicKey: pubKey, + IsCA: false, + Curve: Curve_P256, + } + + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + assert.NoError(t, err) + pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) + rawPriv := priv.D.FillBytes(make([]byte, 32)) + + c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv) + assert.Nil(t, err) + assert.NotNil(t, c) + assert.True(t, c.CheckSignature(pub)) + + b, err := c.Marshal() + assert.Nil(t, err) + uc, err := unmarshalCertificateV1(b, nil) + assert.Nil(t, err) + assert.NotNil(t, uc) +} diff --git a/cert_test/cert.go b/cert_test/cert.go new file mode 100644 index 0000000..ebc6f52 --- /dev/null +++ b/cert_test/cert.go @@ -0,0 +1,138 @@ +package cert_test + +import ( + "crypto/ecdh" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "io" + "net/netip" + "time" + + "github.com/slackhq/nebula/cert" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/ed25519" +) + +// NewTestCaCert will create a new ca certificate +func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { + var err error + var pub, priv []byte + + switch curve { + case cert.Curve_CURVE25519: + pub, priv, err = ed25519.GenerateKey(rand.Reader) + case cert.Curve_P256: + privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(err) + } + + pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y) + priv = privk.D.FillBytes(make([]byte, 32)) + default: + // There is no default to allow the underlying lib to respond with an error + } + + if before.IsZero() { + before = time.Now().Add(time.Second * -60).Round(time.Second) + } + if after.IsZero() { + after = time.Now().Add(time.Second * 60).Round(time.Second) + } + + t := &cert.TBSCertificate{ + Curve: curve, + Version: version, + Name: "test ca", + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + Networks: networks, + UnsafeNetworks: unsafeNetworks, + Groups: groups, + IsCA: true, + } + + c, err := t.Sign(nil, curve, priv) + if err != nil { + panic(err) + } + + pem, err := c.MarshalPEM() + if err != nil { + panic(err) + } + + return c, pub, priv, pem +} + +// NewTestCert will generate a signed certificate with the provided details. +// Expiry times are defaulted if you do not pass them in +func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { + if before.IsZero() { + before = time.Now().Add(time.Second * -60).Round(time.Second) + } + + if after.IsZero() { + after = time.Now().Add(time.Second * 60).Round(time.Second) + } + + var pub, priv []byte + switch curve { + case cert.Curve_CURVE25519: + pub, priv = X25519Keypair() + case cert.Curve_P256: + pub, priv = P256Keypair() + default: + panic("unknown curve") + } + + nc := &cert.TBSCertificate{ + Version: v, + Curve: curve, + Name: name, + Networks: networks, + UnsafeNetworks: unsafeNetworks, + Groups: groups, + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + IsCA: false, + } + + c, err := nc.Sign(ca, ca.Curve(), key) + if err != nil { + panic(err) + } + + pem, err := c.MarshalPEM() + if err != nil { + panic(err) + } + + return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem +} + +func X25519Keypair() ([]byte, []byte) { + privkey := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, privkey); err != nil { + panic(err) + } + + pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint) + if err != nil { + panic(err) + } + + return pubkey, privkey +} + +func P256Keypair() ([]byte, []byte) { + privkey, err := ecdh.P256().GenerateKey(rand.Reader) + if err != nil { + panic(err) + } + pubkey := privkey.PublicKey() + return pubkey.Bytes(), privkey.Bytes() +} diff --git a/cmd/nebula-cert/ca.go b/cmd/nebula-cert/ca.go index 4e5d51d..f83c94f 100644 --- a/cmd/nebula-cert/ca.go +++ b/cmd/nebula-cert/ca.go @@ -8,13 +8,14 @@ import ( "fmt" "io" "math" - "net" + "net/netip" "os" "strings" "time" "github.com/skip2/go-qrcode" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/pkclient" "golang.org/x/crypto/ed25519" ) @@ -26,32 +27,43 @@ type caFlags struct { outCertPath *string outQRPath *string groups *string - ips *string - subnets *string + networks *string + unsafeNetworks *string argonMemory *uint argonIterations *uint argonParallelism *uint encryption *bool + version *uint - curve *string + curve *string + p11url *string + + // Deprecated options + ips *string + subnets *string } func newCaFlags() *caFlags { cf := caFlags{set: flag.NewFlagSet("ca", flag.ContinueOnError)} cf.set.Usage = func() {} cf.name = cf.set.String("name", "", "Required: name of the certificate authority") + cf.version = cf.set.Uint("version", uint(cert.Version2), "Optional: version of the certificate format to use") cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"") cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to") cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to") cf.outQRPath = cf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate") cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use") - cf.ips = cf.set.String("ips", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses") - cf.subnets = cf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets") + cf.networks = cf.set.String("networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks") + cf.unsafeNetworks = cf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks") cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase") cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase") cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase") cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format") cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)") + cf.p11url = p11Flag(cf.set) + + cf.ips = cf.set.String("ips", "", "Deprecated, see -networks") + cf.subnets = cf.set.String("subnets", "", "Deprecated, see -unsafe-networks") return &cf } @@ -76,17 +88,21 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error return err } + isP11 := len(*cf.p11url) > 0 + if err := mustFlagString("name", cf.name); err != nil { return err } - if err := mustFlagString("out-key", cf.outKeyPath); err != nil { - return err + if !isP11 { + if err = mustFlagString("out-key", cf.outKeyPath); err != nil { + return err + } } if err := mustFlagString("out-crt", cf.outCertPath); err != nil { return err } var kdfParams *cert.Argon2Parameters - if *cf.encryption { + if !isP11 && *cf.encryption { if kdfParams, err = parseArgonParameters(*cf.argonMemory, *cf.argonParallelism, *cf.argonIterations); err != nil { return err } @@ -106,44 +122,57 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error } } - var ips []*net.IPNet - if *cf.ips != "" { - for _, rs := range strings.Split(*cf.ips, ",") { + version := cert.Version(*cf.version) + if version != cert.Version1 && version != cert.Version2 { + return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2) + } + + var networks []netip.Prefix + if *cf.networks == "" && *cf.ips != "" { + // Pull up deprecated -ips flag if needed + *cf.networks = *cf.ips + } + + if *cf.networks != "" { + for _, rs := range strings.Split(*cf.networks, ",") { rs := strings.Trim(rs, " ") if rs != "" { - ip, ipNet, err := net.ParseCIDR(rs) + n, err := netip.ParsePrefix(rs) if err != nil { - return newHelpErrorf("invalid ip definition: %s", err) + return newHelpErrorf("invalid -networks definition: %s", rs) } - if ip.To4() == nil { - return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", rs) + if version == cert.Version1 && !n.Addr().Is4() { + return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4, have %s", rs) } - - ipNet.IP = ip - ips = append(ips, ipNet) + networks = append(networks, n) } } } - var subnets []*net.IPNet - if *cf.subnets != "" { - for _, rs := range strings.Split(*cf.subnets, ",") { + var unsafeNetworks []netip.Prefix + if *cf.unsafeNetworks == "" && *cf.subnets != "" { + // Pull up deprecated -subnets flag if needed + *cf.unsafeNetworks = *cf.subnets + } + + if *cf.unsafeNetworks != "" { + for _, rs := range strings.Split(*cf.unsafeNetworks, ",") { rs := strings.Trim(rs, " ") if rs != "" { - _, s, err := net.ParseCIDR(rs) + n, err := netip.ParsePrefix(rs) if err != nil { - return newHelpErrorf("invalid subnet definition: %s", err) + return newHelpErrorf("invalid -unsafe-networks definition: %s", rs) } - if s.IP.To4() == nil { - return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs) + if version == cert.Version1 && !n.Addr().Is4() { + return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4, have %s", rs) } - subnets = append(subnets, s) + unsafeNetworks = append(unsafeNetworks, n) } } } var passphrase []byte - if *cf.encryption { + if !isP11 && *cf.encryption { for i := 0; i < 5; i++ { out.Write([]byte("Enter passphrase: ")) passphrase, err = pr.ReadPassword() @@ -166,74 +195,109 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error var curve cert.Curve var pub, rawPriv []byte - switch *cf.curve { - case "25519", "X25519", "Curve25519", "CURVE25519": - curve = cert.Curve_CURVE25519 - pub, rawPriv, err = ed25519.GenerateKey(rand.Reader) - if err != nil { - return fmt.Errorf("error while generating ed25519 keys: %s", err) - } - case "P256": - var key *ecdsa.PrivateKey - curve = cert.Curve_P256 - key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - return fmt.Errorf("error while generating ecdsa keys: %s", err) + var p11Client *pkclient.PKClient + + if isP11 { + switch *cf.curve { + case "P256": + curve = cert.Curve_P256 + default: + return fmt.Errorf("invalid curve for PKCS#11: %s", *cf.curve) } - // ecdh.PrivateKey lets us get at the encoded bytes, even though - // we aren't using ECDH here. - eKey, err := key.ECDH() + p11Client, err = pkclient.FromUrl(*cf.p11url) if err != nil { - return fmt.Errorf("error while converting ecdsa key: %s", err) + return fmt.Errorf("error while creating PKCS#11 client: %w", err) + } + defer func(client *pkclient.PKClient) { + _ = client.Close() + }(p11Client) + pub, err = p11Client.GetPubKey() + if err != nil { + return fmt.Errorf("error while getting public key with PKCS#11: %w", err) + } + } else { + switch *cf.curve { + case "25519", "X25519", "Curve25519", "CURVE25519": + curve = cert.Curve_CURVE25519 + pub, rawPriv, err = ed25519.GenerateKey(rand.Reader) + if err != nil { + return fmt.Errorf("error while generating ed25519 keys: %s", err) + } + case "P256": + var key *ecdsa.PrivateKey + curve = cert.Curve_P256 + key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return fmt.Errorf("error while generating ecdsa keys: %s", err) + } + + // ecdh.PrivateKey lets us get at the encoded bytes, even though + // we aren't using ECDH here. + eKey, err := key.ECDH() + if err != nil { + return fmt.Errorf("error while converting ecdsa key: %s", err) + } + rawPriv = eKey.Bytes() + pub = eKey.PublicKey().Bytes() + default: + return fmt.Errorf("invalid curve: %s", *cf.curve) } - rawPriv = eKey.Bytes() - pub = eKey.PublicKey().Bytes() } - nc := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: *cf.name, - Groups: groups, - Ips: ips, - Subnets: subnets, - NotBefore: time.Now(), - NotAfter: time.Now().Add(*cf.duration), - PublicKey: pub, - IsCA: true, - Curve: curve, - }, + t := &cert.TBSCertificate{ + Version: version, + Name: *cf.name, + Groups: groups, + Networks: networks, + UnsafeNetworks: unsafeNetworks, + NotBefore: time.Now(), + NotAfter: time.Now().Add(*cf.duration), + PublicKey: pub, + IsCA: true, + Curve: curve, } - if _, err := os.Stat(*cf.outKeyPath); err == nil { - return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath) + if !isP11 { + if _, err := os.Stat(*cf.outKeyPath); err == nil { + return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath) + } } if _, err := os.Stat(*cf.outCertPath); err == nil { return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath) } - err = nc.Sign(curve, rawPriv) - if err != nil { - return fmt.Errorf("error while signing: %s", err) - } - + var c cert.Certificate var b []byte - if *cf.encryption { - b, err = cert.EncryptAndMarshalSigningPrivateKey(curve, rawPriv, passphrase, kdfParams) + + if isP11 { + c, err = t.SignWith(nil, curve, p11Client.SignASN1) if err != nil { - return fmt.Errorf("error while encrypting out-key: %s", err) + return fmt.Errorf("error while signing with PKCS#11: %w", err) } } else { - b = cert.MarshalSigningPrivateKey(curve, rawPriv) + c, err = t.Sign(nil, curve, rawPriv) + if err != nil { + return fmt.Errorf("error while signing: %s", err) + } + + if *cf.encryption { + b, err = cert.EncryptAndMarshalSigningPrivateKey(curve, rawPriv, passphrase, kdfParams) + if err != nil { + return fmt.Errorf("error while encrypting out-key: %s", err) + } + } else { + b = cert.MarshalSigningPrivateKeyToPEM(curve, rawPriv) + } + + err = os.WriteFile(*cf.outKeyPath, b, 0600) + if err != nil { + return fmt.Errorf("error while writing out-key: %s", err) + } } - err = os.WriteFile(*cf.outKeyPath, b, 0600) - if err != nil { - return fmt.Errorf("error while writing out-key: %s", err) - } - - b, err = nc.MarshalToPEM() + b, err = c.MarshalPEM() if err != nil { return fmt.Errorf("error while marshalling certificate: %s", err) } diff --git a/cmd/nebula-cert/ca_test.go b/cmd/nebula-cert/ca_test.go index 3a53405..9da0ad4 100644 --- a/cmd/nebula-cert/ca_test.go +++ b/cmd/nebula-cert/ca_test.go @@ -16,8 +16,6 @@ import ( "github.com/stretchr/testify/assert" ) -//TODO: test file permissions - func Test_caSummary(t *testing.T) { assert.Equal(t, "ca : create a self signed certificate authority", caSummary()) } @@ -43,17 +41,24 @@ func Test_caHelp(t *testing.T) { " -groups string\n"+ " \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+ " -ips string\n"+ - " \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses\n"+ + " Deprecated, see -networks\n"+ " -name string\n"+ " \tRequired: name of the certificate authority\n"+ + " -networks string\n"+ + " \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks\n"+ " -out-crt string\n"+ " \tOptional: path to write the certificate to (default \"ca.crt\")\n"+ " -out-key string\n"+ " \tOptional: path to write the private key to (default \"ca.key\")\n"+ " -out-qr string\n"+ " \tOptional: output a qr code image (png) of the certificate\n"+ + optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+ " -subnets string\n"+ - " \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets\n", + " \tDeprecated, see -unsafe-networks\n"+ + " -unsafe-networks string\n"+ + " \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks\n"+ + " -version uint\n"+ + " \tOptional: version of the certificate format to use (default 2)\n", ob.String(), ) } @@ -82,25 +87,25 @@ func Test_ca(t *testing.T) { // required args assertHelpError(t, ca( - []string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw, + []string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw, ), "-name is required") assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // ipv4 only ips - assertHelpError(t, ca([]string{"-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100") + assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100") assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // ipv4 only subnets - assertHelpError(t, ca([]string{"-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100") + assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100") assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // failed key write ob.Reset() eb.Reset() - args := []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"} + args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"} assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -108,12 +113,12 @@ func Test_ca(t *testing.T) { // create temp key file keyF, err := os.CreateTemp("", "test.key") assert.Nil(t, err) - os.Remove(keyF.Name()) + assert.Nil(t, os.Remove(keyF.Name())) // failed cert write ob.Reset() eb.Reset() - args = []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()} + args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()} assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -121,45 +126,46 @@ func Test_ca(t *testing.T) { // create temp cert file crtF, err := os.CreateTemp("", "test.crt") assert.Nil(t, err) - os.Remove(crtF.Name()) - os.Remove(keyF.Name()) + assert.Nil(t, os.Remove(crtF.Name())) + assert.Nil(t, os.Remove(keyF.Name())) // test proper cert with removed empty groups and subnets ob.Reset() eb.Reset() - args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.Nil(t, ca(args, ob, eb, nopw)) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // read cert and key files rb, _ := os.ReadFile(keyF.Name()) - lKey, b, err := cert.UnmarshalEd25519PrivateKey(rb) + lKey, b, c, err := cert.UnmarshalSigningPrivateKeyFromPEM(rb) + assert.Equal(t, cert.Curve_CURVE25519, c) assert.Len(t, b, 0) assert.Nil(t, err) assert.Len(t, lKey, 64) rb, _ = os.ReadFile(crtF.Name()) - lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb) + lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb) assert.Len(t, b, 0) assert.Nil(t, err) - assert.Equal(t, "test", lCrt.Details.Name) - assert.Len(t, lCrt.Details.Ips, 0) - assert.True(t, lCrt.Details.IsCA) - assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Details.Groups) - assert.Len(t, lCrt.Details.Subnets, 0) - assert.Len(t, lCrt.Details.PublicKey, 32) - assert.Equal(t, time.Duration(time.Minute*100), lCrt.Details.NotAfter.Sub(lCrt.Details.NotBefore)) - assert.Equal(t, "", lCrt.Details.Issuer) - assert.True(t, lCrt.CheckSignature(lCrt.Details.PublicKey)) + assert.Equal(t, "test", lCrt.Name()) + assert.Len(t, lCrt.Networks(), 0) + assert.True(t, lCrt.IsCA()) + assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups()) + assert.Len(t, lCrt.UnsafeNetworks(), 0) + assert.Len(t, lCrt.PublicKey(), 32) + assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore())) + assert.Equal(t, "", lCrt.Issuer()) + assert.True(t, lCrt.CheckSignature(lCrt.PublicKey())) // test encrypted key os.Remove(keyF.Name()) os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.Nil(t, ca(args, ob, eb, testpw)) assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, "", eb.String()) @@ -187,7 +193,7 @@ func Test_ca(t *testing.T) { os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.Error(t, ca(args, ob, eb, errpw)) assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, "", eb.String()) @@ -197,7 +203,7 @@ func Test_ca(t *testing.T) { os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext") assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up assert.Equal(t, "", eb.String()) @@ -207,13 +213,13 @@ func Test_ca(t *testing.T) { os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.Nil(t, ca(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file ob.Reset() eb.Reset() - args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name()) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -222,7 +228,7 @@ func Test_ca(t *testing.T) { os.Remove(keyF.Name()) ob.Reset() eb.Reset() - args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name()) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) diff --git a/cmd/nebula-cert/keygen.go b/cmd/nebula-cert/keygen.go index d94cbf1..496f84c 100644 --- a/cmd/nebula-cert/keygen.go +++ b/cmd/nebula-cert/keygen.go @@ -6,6 +6,8 @@ import ( "io" "os" + "github.com/slackhq/nebula/pkclient" + "github.com/slackhq/nebula/cert" ) @@ -13,8 +15,8 @@ type keygenFlags struct { set *flag.FlagSet outKeyPath *string outPubPath *string - - curve *string + curve *string + p11url *string } func newKeygenFlags() *keygenFlags { @@ -23,6 +25,7 @@ func newKeygenFlags() *keygenFlags { cf.outPubPath = cf.set.String("out-pub", "", "Required: path to write the public key to") cf.outKeyPath = cf.set.String("out-key", "", "Required: path to write the private key to") cf.curve = cf.set.String("curve", "25519", "ECDH Curve (25519, P256)") + cf.p11url = p11Flag(cf.set) return &cf } @@ -33,32 +36,58 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error { return err } - if err := mustFlagString("out-key", cf.outKeyPath); err != nil { - return err + isP11 := len(*cf.p11url) > 0 + + if !isP11 { + if err = mustFlagString("out-key", cf.outKeyPath); err != nil { + return err + } } - if err := mustFlagString("out-pub", cf.outPubPath); err != nil { + if err = mustFlagString("out-pub", cf.outPubPath); err != nil { return err } var pub, rawPriv []byte var curve cert.Curve - switch *cf.curve { - case "25519", "X25519", "Curve25519", "CURVE25519": - pub, rawPriv = x25519Keypair() - curve = cert.Curve_CURVE25519 - case "P256": - pub, rawPriv = p256Keypair() - curve = cert.Curve_P256 - default: - return fmt.Errorf("invalid curve: %s", *cf.curve) + if isP11 { + switch *cf.curve { + case "P256": + curve = cert.Curve_P256 + default: + return fmt.Errorf("invalid curve for PKCS#11: %s", *cf.curve) + } + } else { + switch *cf.curve { + case "25519", "X25519", "Curve25519", "CURVE25519": + pub, rawPriv = x25519Keypair() + curve = cert.Curve_CURVE25519 + case "P256": + pub, rawPriv = p256Keypair() + curve = cert.Curve_P256 + default: + return fmt.Errorf("invalid curve: %s", *cf.curve) + } } - err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600) - if err != nil { - return fmt.Errorf("error while writing out-key: %s", err) + if isP11 { + p11Client, err := pkclient.FromUrl(*cf.p11url) + if err != nil { + return fmt.Errorf("error while creating PKCS#11 client: %w", err) + } + defer func(client *pkclient.PKClient) { + _ = client.Close() + }(p11Client) + pub, err = p11Client.GetPubKey() + if err != nil { + return fmt.Errorf("error while getting public key: %w", err) + } + } else { + err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600) + if err != nil { + return fmt.Errorf("error while writing out-key: %s", err) + } } - - err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKey(curve, pub), 0600) + err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKeyToPEM(curve, pub), 0600) if err != nil { return fmt.Errorf("error while writing out-pub: %s", err) } @@ -72,7 +101,7 @@ func keygenSummary() string { func keygenHelp(out io.Writer) { cf := newKeygenFlags() - out.Write([]byte("Usage of " + os.Args[0] + " " + keygenSummary() + "\n")) + _, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + keygenSummary() + "\n")) cf.set.SetOutput(out) cf.set.PrintDefaults() } diff --git a/cmd/nebula-cert/keygen_test.go b/cmd/nebula-cert/keygen_test.go index 9a3b3f3..fcfd77b 100644 --- a/cmd/nebula-cert/keygen_test.go +++ b/cmd/nebula-cert/keygen_test.go @@ -9,8 +9,6 @@ import ( "github.com/stretchr/testify/assert" ) -//TODO: test file permissions - func Test_keygenSummary(t *testing.T) { assert.Equal(t, "keygen : create a public/private key pair. the public key can be passed to `nebula-cert sign`", keygenSummary()) } @@ -26,7 +24,8 @@ func Test_keygenHelp(t *testing.T) { " -out-key string\n"+ " \tRequired: path to write the private key to\n"+ " -out-pub string\n"+ - " \tRequired: path to write the public key to\n", + " \tRequired: path to write the public key to\n"+ + optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n"), ob.String(), ) } @@ -80,13 +79,15 @@ func Test_keygen(t *testing.T) { // read cert and key files rb, _ := os.ReadFile(keyF.Name()) - lKey, b, err := cert.UnmarshalX25519PrivateKey(rb) + lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb) + assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Len(t, b, 0) assert.Nil(t, err) assert.Len(t, lKey, 32) rb, _ = os.ReadFile(pubF.Name()) - lPub, b, err := cert.UnmarshalX25519PublicKey(rb) + lPub, b, curve, err := cert.UnmarshalPublicKeyFromPEM(rb) + assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Len(t, b, 0) assert.Nil(t, err) assert.Len(t, lPub, 32) diff --git a/cmd/nebula-cert/main_test.go b/cmd/nebula-cert/main_test.go index 3d0fa1b..f332895 100644 --- a/cmd/nebula-cert/main_test.go +++ b/cmd/nebula-cert/main_test.go @@ -3,6 +3,7 @@ package main import ( "bytes" "errors" + "fmt" "io" "os" "testing" @@ -10,8 +11,6 @@ import ( "github.com/stretchr/testify/assert" ) -//TODO: all flag parsing continueOnError will print to stderr on its own currently - func Test_help(t *testing.T) { expected := "Usage of " + os.Args[0] + " :\n" + " Global flags:\n" + @@ -77,8 +76,16 @@ func assertHelpError(t *testing.T, err error, msg string) { case *helpError: // good default: - t.Fatal("err was not a helpError") + t.Fatal(fmt.Sprintf("err was not a helpError: %q, expected %q", err, msg)) } assert.EqualError(t, err, msg) } + +func optionalPkcs11String(msg string) string { + if p11Supported() { + return msg + } else { + return "" + } +} diff --git a/cmd/nebula-cert/p11_cgo.go b/cmd/nebula-cert/p11_cgo.go new file mode 100644 index 0000000..f1f1ec6 --- /dev/null +++ b/cmd/nebula-cert/p11_cgo.go @@ -0,0 +1,15 @@ +//go:build cgo && pkcs11 + +package main + +import ( + "flag" +) + +func p11Supported() bool { + return true +} + +func p11Flag(set *flag.FlagSet) *string { + return set.String("pkcs11", "", "Optional: PKCS#11 URI to an existing private key") +} diff --git a/cmd/nebula-cert/p11_stub.go b/cmd/nebula-cert/p11_stub.go new file mode 100644 index 0000000..5afeaea --- /dev/null +++ b/cmd/nebula-cert/p11_stub.go @@ -0,0 +1,16 @@ +//go:build !cgo || !pkcs11 + +package main + +import ( + "flag" +) + +func p11Supported() bool { + return false +} + +func p11Flag(set *flag.FlagSet) *string { + var ret = "" + return &ret +} diff --git a/cmd/nebula-cert/print.go b/cmd/nebula-cert/print.go index 746d6a3..30e0965 100644 --- a/cmd/nebula-cert/print.go +++ b/cmd/nebula-cert/print.go @@ -45,28 +45,27 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("unable to read cert; %s", err) } - var c *cert.NebulaCertificate + var c cert.Certificate var qrBytes []byte part := 0 + var jsonCerts []cert.Certificate + for { - c, rawCert, err = cert.UnmarshalNebulaCertificateFromPEM(rawCert) + c, rawCert, err = cert.UnmarshalCertificateFromPEM(rawCert) if err != nil { return fmt.Errorf("error while unmarshaling cert: %s", err) } if *pf.json { - b, _ := json.Marshal(c) - out.Write(b) - out.Write([]byte("\n")) - + jsonCerts = append(jsonCerts, c) } else { - out.Write([]byte(c.String())) - out.Write([]byte("\n")) + _, _ = out.Write([]byte(c.String())) + _, _ = out.Write([]byte("\n")) } if *pf.outQRPath != "" { - b, err := c.MarshalToPEM() + b, err := c.MarshalPEM() if err != nil { return fmt.Errorf("error while marshalling cert to PEM: %s", err) } @@ -80,6 +79,12 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { part++ } + if *pf.json { + b, _ := json.Marshal(jsonCerts) + _, _ = out.Write(b) + _, _ = out.Write([]byte("\n")) + } + if *pf.outQRPath != "" { b, err := qrcode.Encode(string(qrBytes), qrcode.Medium, -5) if err != nil { diff --git a/cmd/nebula-cert/print_test.go b/cmd/nebula-cert/print_test.go index 9fa8a54..86795e4 100644 --- a/cmd/nebula-cert/print_test.go +++ b/cmd/nebula-cert/print_test.go @@ -2,6 +2,10 @@ package main import ( "bytes" + "crypto/ed25519" + "crypto/rand" + "encoding/hex" + "net/netip" "os" "testing" "time" @@ -68,25 +72,86 @@ func Test_printCert(t *testing.T) { eb.Reset() tf.Truncate(0) tf.Seek(0, 0) - c := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "test", - Groups: []string{"hi"}, - PublicKey: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}, - }, - Signature: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}, - } + ca, caKey := NewTestCaCert("test ca", nil, nil, time.Time{}, time.Time{}, nil, nil, nil) + c, _ := NewTestCert(ca, caKey, "test", time.Time{}, time.Time{}, []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}, nil, []string{"hi"}) - p, _ := c.MarshalToPEM() + p, _ := c.MarshalPEM() tf.Write(p) tf.Write(p) tf.Write(p) err = printCert([]string{"-path", tf.Name()}, ob, eb) + fp, _ := c.Fingerprint() + pk := hex.EncodeToString(c.PublicKey()) + sig := hex.EncodeToString(c.Signature()) assert.Nil(t, err) assert.Equal( t, - "NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\n", + //"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n", + `{ + "details": { + "curve": "CURVE25519", + "groups": [ + "hi" + ], + "isCa": false, + "issuer": "`+c.Issuer()+`", + "name": "test", + "networks": [ + "10.0.0.123/8" + ], + "notAfter": "0001-01-01T00:00:00Z", + "notBefore": "0001-01-01T00:00:00Z", + "publicKey": "`+pk+`", + "unsafeNetworks": [] + }, + "fingerprint": "`+fp+`", + "signature": "`+sig+`", + "version": 1 +} +{ + "details": { + "curve": "CURVE25519", + "groups": [ + "hi" + ], + "isCa": false, + "issuer": "`+c.Issuer()+`", + "name": "test", + "networks": [ + "10.0.0.123/8" + ], + "notAfter": "0001-01-01T00:00:00Z", + "notBefore": "0001-01-01T00:00:00Z", + "publicKey": "`+pk+`", + "unsafeNetworks": [] + }, + "fingerprint": "`+fp+`", + "signature": "`+sig+`", + "version": 1 +} +{ + "details": { + "curve": "CURVE25519", + "groups": [ + "hi" + ], + "isCa": false, + "issuer": "`+c.Issuer()+`", + "name": "test", + "networks": [ + "10.0.0.123/8" + ], + "notAfter": "0001-01-01T00:00:00Z", + "notBefore": "0001-01-01T00:00:00Z", + "publicKey": "`+pk+`", + "unsafeNetworks": [] + }, + "fingerprint": "`+fp+`", + "signature": "`+sig+`", + "version": 1 +} +`, ob.String(), ) assert.Equal(t, "", eb.String()) @@ -96,26 +161,84 @@ func Test_printCert(t *testing.T) { eb.Reset() tf.Truncate(0) tf.Seek(0, 0) - c = cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "test", - Groups: []string{"hi"}, - PublicKey: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}, - }, - Signature: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}, - } - - p, _ = c.MarshalToPEM() tf.Write(p) tf.Write(p) tf.Write(p) err = printCert([]string{"-json", "-path", tf.Name()}, ob, eb) + fp, _ = c.Fingerprint() + pk = hex.EncodeToString(c.PublicKey()) + sig = hex.EncodeToString(c.Signature()) assert.Nil(t, err) assert.Equal( t, - "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n", + `[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}] +`, ob.String(), ) assert.Equal(t, "", eb.String()) } + +// NewTestCaCert will generate a CA cert +func NewTestCaCert(name string, pubKey, privKey []byte, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte) { + var err error + if pubKey == nil || privKey == nil { + pubKey, privKey, err = ed25519.GenerateKey(rand.Reader) + if err != nil { + panic(err) + } + } + + t := &cert.TBSCertificate{ + Version: cert.Version1, + Name: name, + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pubKey, + Networks: networks, + UnsafeNetworks: unsafeNetworks, + Groups: groups, + IsCA: true, + } + + c, err := t.Sign(nil, cert.Curve_CURVE25519, privKey) + if err != nil { + panic(err) + } + + return c, privKey +} + +func NewTestCert(ca cert.Certificate, signerKey []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte) { + if before.IsZero() { + before = ca.NotBefore() + } + + if after.IsZero() { + after = ca.NotAfter() + } + + if len(networks) == 0 { + networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")} + } + + pub, rawPriv := x25519Keypair() + nc := &cert.TBSCertificate{ + Version: cert.Version1, + Name: name, + Networks: networks, + UnsafeNetworks: unsafeNetworks, + Groups: groups, + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + IsCA: false, + } + + c, err := nc.Sign(ca, ca.Curve(), signerKey) + if err != nil { + panic(err) + } + + return c, rawPriv +} diff --git a/cmd/nebula-cert/sign.go b/cmd/nebula-cert/sign.go index 35d6446..ebcb592 100644 --- a/cmd/nebula-cert/sign.go +++ b/cmd/nebula-cert/sign.go @@ -3,50 +3,63 @@ package main import ( "crypto/ecdh" "crypto/rand" + "errors" "flag" "fmt" "io" - "net" + "net/netip" "os" "strings" "time" "github.com/skip2/go-qrcode" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/pkclient" "golang.org/x/crypto/curve25519" ) type signFlags struct { - set *flag.FlagSet - caKeyPath *string - caCertPath *string - name *string - ip *string - duration *time.Duration - inPubPath *string - outKeyPath *string - outCertPath *string - outQRPath *string - groups *string - subnets *string + set *flag.FlagSet + version *uint + caKeyPath *string + caCertPath *string + name *string + networks *string + unsafeNetworks *string + duration *time.Duration + inPubPath *string + outKeyPath *string + outCertPath *string + outQRPath *string + groups *string + + p11url *string + + // Deprecated options + ip *string + subnets *string } func newSignFlags() *signFlags { sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)} sf.set.Usage = func() {} + sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.") sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key") sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert") sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname") - sf.ip = sf.set.String("ip", "", "Required: ipv4 address and network in CIDR notation to assign the cert") + sf.networks = sf.set.String("networks", "", "Required: comma separated list of ip address and network in CIDR notation to assign to this cert") + sf.unsafeNetworks = sf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for") sf.duration = sf.set.Duration("duration", 0, "Optional: how long the cert should be valid for. The default is 1 second before the signing cert expires. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"") sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key") sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to") sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to") sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate") sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups") - sf.subnets = sf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for") - return &sf + sf.p11url = p11Flag(sf.set) + sf.ip = sf.set.String("ip", "", "Deprecated, see -networks") + sf.subnets = sf.set.String("subnets", "", "Deprecated, see -unsafe-networks") + return &sf } func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error { @@ -56,8 +69,12 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return err } - if err := mustFlagString("ca-key", sf.caKeyPath); err != nil { - return err + isP11 := len(*sf.p11url) > 0 + + if !isP11 { + if err := mustFlagString("ca-key", sf.caKeyPath); err != nil { + return err + } } if err := mustFlagString("ca-crt", sf.caCertPath); err != nil { return err @@ -65,50 +82,67 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) if err := mustFlagString("name", sf.name); err != nil { return err } - if err := mustFlagString("ip", sf.ip); err != nil { - return err - } - if *sf.inPubPath != "" && *sf.outKeyPath != "" { + if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" { return newHelpErrorf("cannot set both -in-pub and -out-key") } - rawCAKey, err := os.ReadFile(*sf.caKeyPath) - if err != nil { - return fmt.Errorf("error while reading ca-key: %s", err) + var v4Networks []netip.Prefix + var v6Networks []netip.Prefix + if *sf.networks == "" && *sf.ip != "" { + // Pull up deprecated -ip flag if needed + *sf.networks = *sf.ip + } + + if len(*sf.networks) == 0 { + return newHelpErrorf("-networks is required") + } + + version := cert.Version(*sf.version) + if version != 0 && version != cert.Version1 && version != cert.Version2 { + return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2) } var curve cert.Curve var caKey []byte - // naively attempt to decode the private key as though it is not encrypted - caKey, _, curve, err = cert.UnmarshalSigningPrivateKey(rawCAKey) - if err == cert.ErrPrivateKeyEncrypted { - // ask for a passphrase until we get one - var passphrase []byte - for i := 0; i < 5; i++ { - out.Write([]byte("Enter passphrase: ")) - passphrase, err = pr.ReadPassword() + if !isP11 { + var rawCAKey []byte + rawCAKey, err := os.ReadFile(*sf.caKeyPath) - if err == ErrNoTerminal { - return fmt.Errorf("ca-key is encrypted and must be decrypted interactively") - } else if err != nil { - return fmt.Errorf("error reading password: %s", err) - } - - if len(passphrase) > 0 { - break - } - } - if len(passphrase) == 0 { - return fmt.Errorf("cannot open encrypted ca-key without passphrase") - } - - curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey) if err != nil { - return fmt.Errorf("error while parsing encrypted ca-key: %s", err) + return fmt.Errorf("error while reading ca-key: %s", err) + } + + // naively attempt to decode the private key as though it is not encrypted + caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey) + if errors.Is(err, cert.ErrPrivateKeyEncrypted) { + // ask for a passphrase until we get one + var passphrase []byte + for i := 0; i < 5; i++ { + out.Write([]byte("Enter passphrase: ")) + passphrase, err = pr.ReadPassword() + + if errors.Is(err, ErrNoTerminal) { + return fmt.Errorf("ca-key is encrypted and must be decrypted interactively") + } else if err != nil { + return fmt.Errorf("error reading password: %s", err) + } + + if len(passphrase) > 0 { + break + } + } + if len(passphrase) == 0 { + return fmt.Errorf("cannot open encrypted ca-key without passphrase") + } + + curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey) + if err != nil { + return fmt.Errorf("error while parsing encrypted ca-key: %s", err) + } + } else if err != nil { + return fmt.Errorf("error while parsing ca-key: %s", err) } - } else if err != nil { - return fmt.Errorf("error while parsing ca-key: %s", err) } rawCACert, err := os.ReadFile(*sf.caCertPath) @@ -116,18 +150,15 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return fmt.Errorf("error while reading ca-crt: %s", err) } - caCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCACert) + caCert, _, err := cert.UnmarshalCertificateFromPEM(rawCACert) if err != nil { return fmt.Errorf("error while parsing ca-crt: %s", err) } - if err := caCert.VerifyPrivateKey(curve, caKey); err != nil { - return fmt.Errorf("refusing to sign, root certificate does not match private key") - } - - issuer, err := caCert.Sha256Sum() - if err != nil { - return fmt.Errorf("error while getting -ca-crt fingerprint: %s", err) + if !isP11 { + if err := caCert.VerifyPrivateKey(curve, caKey); err != nil { + return fmt.Errorf("refusing to sign, root certificate does not match private key") + } } if caCert.Expired(time.Now()) { @@ -136,19 +167,53 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) // if no duration is given, expire one second before the root expires if *sf.duration <= 0 { - *sf.duration = time.Until(caCert.Details.NotAfter) - time.Second*1 + *sf.duration = time.Until(caCert.NotAfter()) - time.Second*1 } - ip, ipNet, err := net.ParseCIDR(*sf.ip) - if err != nil { - return newHelpErrorf("invalid ip definition: %s", err) - } - if ip.To4() == nil { - return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", *sf.ip) - } - ipNet.IP = ip + if *sf.networks != "" { + for _, rs := range strings.Split(*sf.networks, ",") { + rs := strings.Trim(rs, " ") + if rs != "" { + n, err := netip.ParsePrefix(rs) + if err != nil { + return newHelpErrorf("invalid -networks definition: %s", rs) + } - groups := []string{} + if n.Addr().Is4() { + v4Networks = append(v4Networks, n) + } else { + v6Networks = append(v6Networks, n) + } + } + } + } + + var v4UnsafeNetworks []netip.Prefix + var v6UnsafeNetworks []netip.Prefix + if *sf.unsafeNetworks == "" && *sf.subnets != "" { + // Pull up deprecated -subnets flag if needed + *sf.unsafeNetworks = *sf.subnets + } + + if *sf.unsafeNetworks != "" { + for _, rs := range strings.Split(*sf.unsafeNetworks, ",") { + rs := strings.Trim(rs, " ") + if rs != "" { + n, err := netip.ParsePrefix(rs) + if err != nil { + return newHelpErrorf("invalid -unsafe-networks definition: %s", rs) + } + + if n.Addr().Is4() { + v4UnsafeNetworks = append(v4UnsafeNetworks, n) + } else { + v6UnsafeNetworks = append(v6UnsafeNetworks, n) + } + } + } + } + + var groups []string if *sf.groups != "" { for _, rg := range strings.Split(*sf.groups, ",") { g := strings.TrimSpace(rg) @@ -158,60 +223,43 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) } } - subnets := []*net.IPNet{} - if *sf.subnets != "" { - for _, rs := range strings.Split(*sf.subnets, ",") { - rs := strings.Trim(rs, " ") - if rs != "" { - _, s, err := net.ParseCIDR(rs) - if err != nil { - return newHelpErrorf("invalid subnet definition: %s", err) - } - if s.IP.To4() == nil { - return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs) - } - subnets = append(subnets, s) - } + var pub, rawPriv []byte + var p11Client *pkclient.PKClient + + if isP11 { + curve = cert.Curve_P256 + p11Client, err = pkclient.FromUrl(*sf.p11url) + if err != nil { + return fmt.Errorf("error while creating PKCS#11 client: %w", err) } + defer func(client *pkclient.PKClient) { + _ = client.Close() + }(p11Client) } - var pub, rawPriv []byte if *sf.inPubPath != "" { + var pubCurve cert.Curve rawPub, err := os.ReadFile(*sf.inPubPath) if err != nil { return fmt.Errorf("error while reading in-pub: %s", err) } - var pubCurve cert.Curve - pub, _, pubCurve, err = cert.UnmarshalPublicKey(rawPub) + + pub, _, pubCurve, err = cert.UnmarshalPublicKeyFromPEM(rawPub) if err != nil { return fmt.Errorf("error while parsing in-pub: %s", err) } if pubCurve != curve { return fmt.Errorf("curve of in-pub does not match ca") } + } else if isP11 { + pub, err = p11Client.GetPubKey() + if err != nil { + return fmt.Errorf("error while getting public key with PKCS#11: %w", err) + } } else { pub, rawPriv = newKeypair(curve) } - nc := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: *sf.name, - Ips: []*net.IPNet{ipNet}, - Groups: groups, - Subnets: subnets, - NotBefore: time.Now(), - NotAfter: time.Now().Add(*sf.duration), - PublicKey: pub, - IsCA: false, - Issuer: issuer, - Curve: curve, - }, - } - - if err := nc.CheckRootConstrains(caCert); err != nil { - return fmt.Errorf("refusing to sign, root certificate constraints violated: %s", err) - } - if *sf.outKeyPath == "" { *sf.outKeyPath = *sf.name + ".key" } @@ -224,25 +272,105 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath) } - err = nc.Sign(curve, caKey) - if err != nil { - return fmt.Errorf("error while signing: %s", err) + var crts []cert.Certificate + + notBefore := time.Now() + notAfter := notBefore.Add(*sf.duration) + + if version == 0 || version == cert.Version1 { + // Make sure we at least have an ip + if len(v4Networks) != 1 { + return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address") + } + + if version == cert.Version1 { + // If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses + if len(v6Networks) > 0 { + return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4") + } + + if len(v6UnsafeNetworks) > 0 { + return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4") + } + } + + t := &cert.TBSCertificate{ + Version: cert.Version1, + Name: *sf.name, + Networks: []netip.Prefix{v4Networks[0]}, + Groups: groups, + UnsafeNetworks: v4UnsafeNetworks, + NotBefore: notBefore, + NotAfter: notAfter, + PublicKey: pub, + IsCA: false, + Curve: curve, + } + + var nc cert.Certificate + if p11Client == nil { + nc, err = t.Sign(caCert, curve, caKey) + if err != nil { + return fmt.Errorf("error while signing: %w", err) + } + } else { + nc, err = t.SignWith(caCert, curve, p11Client.SignASN1) + if err != nil { + return fmt.Errorf("error while signing with PKCS#11: %w", err) + } + } + + crts = append(crts, nc) } - if *sf.inPubPath == "" { + if version == 0 || version == cert.Version2 { + t := &cert.TBSCertificate{ + Version: cert.Version2, + Name: *sf.name, + Networks: append(v4Networks, v6Networks...), + Groups: groups, + UnsafeNetworks: append(v4UnsafeNetworks, v6UnsafeNetworks...), + NotBefore: notBefore, + NotAfter: notAfter, + PublicKey: pub, + IsCA: false, + Curve: curve, + } + + var nc cert.Certificate + if p11Client == nil { + nc, err = t.Sign(caCert, curve, caKey) + if err != nil { + return fmt.Errorf("error while signing: %w", err) + } + } else { + nc, err = t.SignWith(caCert, curve, p11Client.SignASN1) + if err != nil { + return fmt.Errorf("error while signing with PKCS#11: %w", err) + } + } + + crts = append(crts, nc) + } + + if !isP11 && *sf.inPubPath == "" { if _, err := os.Stat(*sf.outKeyPath); err == nil { return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath) } - err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600) + err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600) if err != nil { return fmt.Errorf("error while writing out-key: %s", err) } } - b, err := nc.MarshalToPEM() - if err != nil { - return fmt.Errorf("error while marshalling certificate: %s", err) + var b []byte + for _, c := range crts { + sb, err := c.MarshalPEM() + if err != nil { + return fmt.Errorf("error while marshalling certificate: %s", err) + } + b = append(b, sb...) } err = os.WriteFile(*sf.outCertPath, b, 0600) diff --git a/cmd/nebula-cert/sign_test.go b/cmd/nebula-cert/sign_test.go index adf83a2..466cb8c 100644 --- a/cmd/nebula-cert/sign_test.go +++ b/cmd/nebula-cert/sign_test.go @@ -16,8 +16,6 @@ import ( "golang.org/x/crypto/ed25519" ) -//TODO: test file permissions - func Test_signSummary(t *testing.T) { assert.Equal(t, "sign : create and sign a certificate", signSummary()) } @@ -39,17 +37,24 @@ func Test_signHelp(t *testing.T) { " -in-pub string\n"+ " \tOptional (if out-key not set): path to read a previously generated public key\n"+ " -ip string\n"+ - " \tRequired: ipv4 address and network in CIDR notation to assign the cert\n"+ + " \tDeprecated, see -networks\n"+ " -name string\n"+ " \tRequired: name of the cert, usually a hostname\n"+ + " -networks string\n"+ + " \tRequired: comma separated list of ip address and network in CIDR notation to assign to this cert\n"+ " -out-crt string\n"+ " \tOptional: path to write the certificate to\n"+ " -out-key string\n"+ " \tOptional (if in-pub not set): path to write the private key to\n"+ " -out-qr string\n"+ " \tOptional: output a qr code image (png) of the certificate\n"+ + optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+ " -subnets string\n"+ - " \tOptional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for\n", + " \tDeprecated, see -unsafe-networks\n"+ + " -unsafe-networks string\n"+ + " \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+ + " -version uint\n"+ + " \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n", ob.String(), ) } @@ -76,20 +81,20 @@ func Test_signCert(t *testing.T) { // required args assertHelpError(t, signCert( - []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, + []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, ), "-name is required") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) assertHelpError(t, signCert( - []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, - ), "-ip is required") + []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, + ), "-networks is required") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // cannot set -in-pub and -out-key assertHelpError(t, signCert( - []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw, + []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw, ), "cannot set both -in-pub and -out-key") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -97,7 +102,7 @@ func Test_signCert(t *testing.T) { // failed to read key ob.Reset() eb.Reset() - args := []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError) // failed to unmarshal key @@ -107,7 +112,7 @@ func Test_signCert(t *testing.T) { assert.Nil(t, err) defer os.Remove(caKeyF.Name()) - args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -116,10 +121,10 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader) - caKeyF.Write(cert.MarshalEd25519PrivateKey(caPriv)) + caKeyF.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv)) // failed to read cert - args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -131,26 +136,18 @@ func Test_signCert(t *testing.T) { assert.Nil(t, err) defer os.Remove(caCrtF.Name()) - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // write a proper ca cert for later - ca := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "ca", - NotBefore: time.Now(), - NotAfter: time.Now().Add(time.Minute * 200), - PublicKey: caPub, - IsCA: true, - }, - } - b, _ := ca.MarshalToPEM() + ca, _ := NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil) + b, _ := ca.MarshalPEM() caCrtF.Write(b) // failed to read pub - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -162,7 +159,7 @@ func Test_signCert(t *testing.T) { assert.Nil(t, err) defer os.Remove(inPubF.Name()) - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -171,35 +168,42 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() inPub, _ := x25519Keypair() - inPubF.Write(cert.MarshalX25519PublicKey(inPub)) + inPubF.Write(cert.MarshalPublicKeyToPEM(cert.Curve_CURVE25519, inPub)) // bad ip cidr ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: invalid CIDR address: a1.1.1.1/24") + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: a1.1.1.1/24") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100") + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address") + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + + ob.Reset() + eb.Reset() + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24,1.1.1.2/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // bad subnet cidr ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: invalid CIDR address: a") + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: a") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100") + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"} + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -208,11 +212,11 @@ func Test_signCert(t *testing.T) { caKeyF2, err := os.CreateTemp("", "sign-cert-2.key") assert.Nil(t, err) defer os.Remove(caKeyF2.Name()) - caKeyF2.Write(cert.MarshalEd25519PrivateKey(caPriv2)) + caKeyF2.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv2)) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -220,7 +224,7 @@ func Test_signCert(t *testing.T) { // failed key write ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -233,7 +237,7 @@ func Test_signCert(t *testing.T) { // failed cert write ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -247,40 +251,41 @@ func Test_signCert(t *testing.T) { // test proper cert with removed empty groups and subnets ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Nil(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // read cert and key files rb, _ := os.ReadFile(keyF.Name()) - lKey, b, err := cert.UnmarshalX25519PrivateKey(rb) + lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb) + assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Len(t, b, 0) assert.Nil(t, err) assert.Len(t, lKey, 32) rb, _ = os.ReadFile(crtF.Name()) - lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb) + lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb) assert.Len(t, b, 0) assert.Nil(t, err) - assert.Equal(t, "test", lCrt.Details.Name) - assert.Equal(t, "1.1.1.1/24", lCrt.Details.Ips[0].String()) - assert.Len(t, lCrt.Details.Ips, 1) - assert.False(t, lCrt.Details.IsCA) - assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Details.Groups) - assert.Len(t, lCrt.Details.Subnets, 3) - assert.Len(t, lCrt.Details.PublicKey, 32) - assert.Equal(t, time.Duration(time.Minute*100), lCrt.Details.NotAfter.Sub(lCrt.Details.NotBefore)) + assert.Equal(t, "test", lCrt.Name()) + assert.Equal(t, "1.1.1.1/24", lCrt.Networks()[0].String()) + assert.Len(t, lCrt.Networks(), 1) + assert.False(t, lCrt.IsCA()) + assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups()) + assert.Len(t, lCrt.UnsafeNetworks(), 3) + assert.Len(t, lCrt.PublicKey(), 32) + assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore())) sns := []string{} - for _, sn := range lCrt.Details.Subnets { + for _, sn := range lCrt.UnsafeNetworks() { sns = append(sns, sn.String()) } assert.Equal(t, []string{"10.1.1.1/32", "10.2.2.2/32", "10.5.5.5/32"}, sns) - issuer, _ := ca.Sha256Sum() - assert.Equal(t, issuer, lCrt.Details.Issuer) + issuer, _ := ca.Fingerprint() + assert.Equal(t, issuer, lCrt.Issuer()) assert.True(t, lCrt.CheckSignature(caPub)) @@ -289,37 +294,39 @@ func Test_signCert(t *testing.T) { os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"} assert.Nil(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // read cert file and check pub key matches in-pub rb, _ = os.ReadFile(crtF.Name()) - lCrt, b, err = cert.UnmarshalNebulaCertificateFromPEM(rb) + lCrt, b, err = cert.UnmarshalCertificateFromPEM(rb) assert.Len(t, b, 0) assert.Nil(t, err) - assert.Equal(t, lCrt.Details.PublicKey, inPub) + assert.Equal(t, lCrt.PublicKey(), inPub) // test refuse to sign cert with duration beyond root ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate constraints violated: certificate expires after signing certificate") + os.Remove(keyF.Name()) + os.Remove(crtF.Name()) + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + assert.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // create valid cert/key for overwrite tests os.Remove(keyF.Name()) os.Remove(crtF.Name()) - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Nil(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing key file os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -327,14 +334,14 @@ func Test_signCert(t *testing.T) { // create valid cert/key for overwrite tests os.Remove(keyF.Name()) os.Remove(crtF.Name()) - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Nil(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file os.Remove(keyF.Name()) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -361,20 +368,12 @@ func Test_signCert(t *testing.T) { b, _ = cert.EncryptAndMarshalSigningPrivateKey(cert.Curve_CURVE25519, caPriv, passphrase, kdfParams) caKeyF.Write(b) - ca = cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "ca", - NotBefore: time.Now(), - NotAfter: time.Now().Add(time.Minute * 200), - PublicKey: caPub, - IsCA: true, - }, - } - b, _ = ca.MarshalToPEM() + ca, _ = NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil) + b, _ = ca.MarshalPEM() caCrtF.Write(b) // test with the proper password - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Nil(t, signCert(args, ob, eb, testpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) @@ -384,7 +383,7 @@ func Test_signCert(t *testing.T) { eb.Reset() testpw.password = []byte("invalid password") - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Error(t, signCert(args, ob, eb, testpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) @@ -393,7 +392,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Error(t, signCert(args, ob, eb, nopw)) // normally the user hitting enter on the prompt would add newlines between these assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String()) @@ -403,7 +402,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Error(t, signCert(args, ob, eb, errpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) diff --git a/cmd/nebula-cert/verify.go b/cmd/nebula-cert/verify.go index c955913..bea4d1d 100644 --- a/cmd/nebula-cert/verify.go +++ b/cmd/nebula-cert/verify.go @@ -1,6 +1,7 @@ package main import ( + "errors" "flag" "fmt" "io" @@ -41,14 +42,14 @@ func verify(args []string, out io.Writer, errOut io.Writer) error { rawCACert, err := os.ReadFile(*vf.caPath) if err != nil { - return fmt.Errorf("error while reading ca: %s", err) + return fmt.Errorf("error while reading ca: %w", err) } caPool := cert.NewCAPool() for { - rawCACert, err = caPool.AddCACertificate(rawCACert) + rawCACert, err = caPool.AddCAFromPEM(rawCACert) if err != nil { - return fmt.Errorf("error while adding ca cert to pool: %s", err) + return fmt.Errorf("error while adding ca cert to pool: %w", err) } if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" { @@ -58,20 +59,30 @@ func verify(args []string, out io.Writer, errOut io.Writer) error { rawCert, err := os.ReadFile(*vf.certPath) if err != nil { - return fmt.Errorf("unable to read crt; %s", err) + return fmt.Errorf("unable to read crt: %w", err) + } + var errs []error + for { + if len(rawCert) == 0 { + break + } + c, extra, err := cert.UnmarshalCertificateFromPEM(rawCert) + if err != nil { + return fmt.Errorf("error while parsing crt: %w", err) + } + rawCert = extra + _, err = caPool.VerifyCertificate(time.Now(), c) + if err != nil { + switch { + case errors.Is(err, cert.ErrCaNotFound): + errs = append(errs, fmt.Errorf("error while verifying certificate v%d %s with issuer %s: %w", c.Version(), c.Name(), c.Issuer(), err)) + default: + errs = append(errs, fmt.Errorf("error while verifying certificate %+v: %w", c, err)) + } + } } - c, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert) - if err != nil { - return fmt.Errorf("error while parsing crt: %s", err) - } - - good, err := c.Verify(time.Now(), caPool) - if !good { - return err - } - - return nil + return errors.Join(errs...) } func verifySummary() string { @@ -80,7 +91,7 @@ func verifySummary() string { func verifyHelp(out io.Writer) { vf := newVerifyFlags() - out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n")) + _, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n")) vf.set.SetOutput(out) vf.set.PrintDefaults() } diff --git a/cmd/nebula-cert/verify_test.go b/cmd/nebula-cert/verify_test.go index f0f4c78..d94bd1f 100644 --- a/cmd/nebula-cert/verify_test.go +++ b/cmd/nebula-cert/verify_test.go @@ -3,6 +3,7 @@ package main import ( "bytes" "crypto/rand" + "errors" "os" "testing" "time" @@ -67,17 +68,8 @@ func Test_verify(t *testing.T) { // make a ca for later caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader) - ca := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "test-ca", - NotBefore: time.Now().Add(time.Hour * -1), - NotAfter: time.Now().Add(time.Hour * 2), - PublicKey: caPub, - IsCA: true, - }, - } - ca.Sign(cert.Curve_CURVE25519, caPriv) - b, _ := ca.MarshalToPEM() + ca, _ := NewTestCaCert("test-ca", caPub, caPriv, time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour*2), nil, nil, nil) + b, _ := ca.MarshalPEM() caFile.Truncate(0) caFile.Seek(0, 0) caFile.Write(b) @@ -86,7 +78,7 @@ func Test_verify(t *testing.T) { err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "unable to read crt; open does_not_exist: "+NoSuchFileError) + assert.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError) // invalid crt at path ob.Reset() @@ -102,22 +94,13 @@ func Test_verify(t *testing.T) { assert.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block") // unverifiable cert at path - _, badPriv, _ := ed25519.GenerateKey(rand.Reader) - certPub, _ := x25519Keypair() - signer, _ := ca.Sha256Sum() - crt := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "test-cert", - NotBefore: time.Now().Add(time.Hour * -1), - NotAfter: time.Now().Add(time.Hour), - PublicKey: certPub, - IsCA: false, - Issuer: signer, - }, + crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) + // Slightly evil hack to modify the certificate after it was sealed to generate an invalid signature + pub := crt.PublicKey() + for i, _ := range pub { + pub[i] = 0 } - - crt.Sign(cert.Curve_CURVE25519, badPriv) - b, _ = crt.MarshalToPEM() + b, _ = crt.MarshalPEM() certFile.Truncate(0) certFile.Seek(0, 0) certFile.Write(b) @@ -125,11 +108,11 @@ func Test_verify(t *testing.T) { err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "certificate signature did not match") + assert.True(t, errors.Is(err, cert.ErrSignatureMismatch)) // verified cert at path - crt.Sign(cert.Curve_CURVE25519, caPriv) - b, _ = crt.MarshalToPEM() + crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) + b, _ = crt.MarshalPEM() certFile.Truncate(0) certFile.Seek(0, 0) certFile.Write(b) diff --git a/config/config_test.go b/config/config_test.go index fa94393..c3a1a73 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -38,9 +38,6 @@ func TestConfig_Load(t *testing.T) { "new": "hi", } assert.Equal(t, expected, c.Settings) - - //TODO: test symlinked file - //TODO: test symlinked directory } func TestConfig_Get(t *testing.T) { diff --git a/connection_manager.go b/connection_manager.go index d2e8616..9d8d071 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -183,7 +183,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, case deleteTunnel: if n.hostMap.DeleteHostInfo(hostinfo) { // Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap - n.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp) + n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs) } case closeTunnel: @@ -221,7 +221,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) relayFor := oldhostinfo.relayState.CopyAllRelayFor() for _, r := range relayFor { - existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp) + existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerAddr) var index uint32 var relayFrom netip.Addr @@ -235,11 +235,11 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) index = existing.LocalIndex switch r.Type { case TerminalType: - relayFrom = n.intf.myVpnNet.Addr() - relayTo = existing.PeerIp + relayFrom = n.intf.myVpnAddrs[0] + relayTo = existing.PeerAddr case ForwardingType: - relayFrom = existing.PeerIp - relayTo = newhostinfo.vpnIp + relayFrom = existing.PeerAddr + relayTo = newhostinfo.vpnAddrs[0] default: // should never happen } @@ -253,45 +253,64 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) n.relayUsedLock.RUnlock() // The relay doesn't exist at all; create some relay state and send the request. var err error - index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested) + index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested) if err != nil { n.l.WithError(err).Error("failed to migrate relay to new hostinfo") continue } switch r.Type { case TerminalType: - relayFrom = n.intf.myVpnNet.Addr() - relayTo = r.PeerIp + relayFrom = n.intf.myVpnAddrs[0] + relayTo = r.PeerAddr case ForwardingType: - relayFrom = r.PeerIp - relayTo = newhostinfo.vpnIp + relayFrom = r.PeerAddr + relayTo = newhostinfo.vpnAddrs[0] default: // should never happen } } - //TODO: IPV6-WORK - relayFromB := relayFrom.As4() - relayToB := relayTo.As4() - // Send a CreateRelayRequest to the peer. req := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: index, - RelayFromIp: binary.BigEndian.Uint32(relayFromB[:]), - RelayToIp: binary.BigEndian.Uint32(relayToB[:]), } + + switch newhostinfo.GetCert().Certificate.Version() { + case cert.Version1: + if !relayFrom.Is4() { + n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version") + continue + } + + if !relayTo.Is4() { + n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version") + continue + } + + b := relayFrom.As4() + req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = relayTo.As4() + req.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + case cert.Version2: + req.RelayFromAddr = netAddrToProtoAddr(relayFrom) + req.RelayToAddr = netAddrToProtoAddr(relayTo) + default: + newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay") + continue + } + msg, err := req.Marshal() if err != nil { n.l.WithError(err).Error("failed to marshal Control message to migrate relay") } else { n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) n.l.WithFields(logrus.Fields{ - "relayFrom": req.RelayFromIp, - "relayTo": req.RelayToIp, + "relayFrom": req.RelayFromAddr, + "relayTo": req.RelayToAddr, "initiatorRelayIndex": req.InitiatorRelayIndex, "responderRelayIndex": req.ResponderRelayIndex, - "vpnIp": newhostinfo.vpnIp}). + "vpnAddrs": newhostinfo.vpnAddrs}). Info("send CreateRelayRequest") } } @@ -313,7 +332,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time return closeTunnel, hostinfo, nil } - primary := n.hostMap.Hosts[hostinfo.vpnIp] + primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]] mainHostInfo := true if primary != nil && primary != hostinfo { mainHostInfo = false @@ -407,21 +426,24 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary. // Let's sort this out. - if current.vpnIp.Compare(n.intf.myVpnNet.Addr()) < 0 { - // Only one side should flip primary because if both flip then we may never resolve to a single tunnel. - // vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping. - // The remotes vpn ip is lower than mine. I will not flip. + // Only one side should swap because if both swap then we may never resolve to a single tunnel. + // vpn addr is static across all tunnels for this host pair so lets + // use that to determine if we should consider swapping. + if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 { + // Their primary vpn addr is less than mine. Do not swap. return false } - certState := n.intf.pki.GetCertState() - return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature) + crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version()) + // If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things + // settle down. + return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature()) } func (n *connectionManager) swapPrimary(current, primary *HostInfo) { n.hostMap.Lock() // Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake. - if n.hostMap.Hosts[current.vpnIp] == primary { + if n.hostMap.Hosts[current.vpnAddrs[0]] == primary { n.hostMap.unlockedMakePrimary(current) } n.hostMap.Unlock() @@ -436,8 +458,9 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn return false } - valid, err := remoteCert.VerifyWithCache(now, n.intf.pki.GetCAPool()) - if valid { + caPool := n.intf.pki.GetCAPool() + err := caPool.VerifyCachedCertificate(now, remoteCert) + if err == nil { return false } @@ -446,9 +469,8 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn return false } - fingerprint, _ := remoteCert.Sha256Sum() hostinfo.logger(n.l).WithError(err). - WithField("fingerprint", fingerprint). + WithField("fingerprint", remoteCert.Fingerprint). Info("Remote certificate is no longer valid, tearing down the tunnel") return true @@ -473,14 +495,17 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) { } func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { - certState := n.intf.pki.GetCertState() - if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) { + cs := n.intf.pki.getCertState() + curCrt := hostinfo.ConnectionState.myCert + myCrt := cs.getCertificate(curCrt.Version()) + if curCrt.Version() >= cs.defaultVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true { + // The current tunnel is using the latest certificate and version, no need to rehandshake. return } - n.l.WithField("vpnIp", hostinfo.vpnIp). + n.l.WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("reason", "local certificate is not current"). Info("Re-handshaking with remote") - n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil) + n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) } diff --git a/connection_manager_test.go b/connection_manager_test.go index 5f97cad..8e2ef15 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -4,7 +4,6 @@ import ( "context" "crypto/ed25519" "crypto/rand" - "net" "net/netip" "testing" "time" @@ -35,20 +34,19 @@ func newTestLighthouse() *LightHouse { func Test_NewConnectionManagerTest(t *testing.T) { l := test.NewLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") - vpncidr := netip.MustParsePrefix("172.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24") vpnIp := netip.MustParseAddr("172.1.1.2") preferredRanges := []netip.Prefix{localrange} // Very incomplete mock objects - hostMap := newHostMap(l, vpncidr) + hostMap := newHostMap(l) hostMap.preferredRanges.Store(&preferredRanges) cs := &CertState{ - RawCertificate: []byte{}, - PrivateKey: []byte{}, - Certificate: &cert.NebulaCertificate{}, - RawCertificateNoKey: []byte{}, + defaultVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1}, + v1HandshakeBytes: []byte{}, } lh := newTestLighthouse() @@ -75,12 +73,12 @@ func Test_NewConnectionManagerTest(t *testing.T) { // Add an ip we have established a connection w/ to hostmap hostinfo := &HostInfo{ - vpnIp: vpnIp, + vpnAddrs: []netip.Addr{vpnIp}, localIndexId: 1099, remoteIndexId: 9901, } hostinfo.ConnectionState = &ConnectionState{ - myCert: &cert.NebulaCertificate{}, + myCert: &dummyCert{version: cert.Version1}, H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -89,7 +87,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { nc.Out(hostinfo.localIndexId) nc.In(hostinfo.localIndexId) assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.out, hostinfo.localIndexId) @@ -106,32 +104,31 @@ func Test_NewConnectionManagerTest(t *testing.T) { assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) // Do a final traffic check tick, the host should now be removed nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId) } func Test_NewConnectionManagerTest2(t *testing.T) { l := test.NewLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") - vpncidr := netip.MustParsePrefix("172.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24") vpnIp := netip.MustParseAddr("172.1.1.2") preferredRanges := []netip.Prefix{localrange} // Very incomplete mock objects - hostMap := newHostMap(l, vpncidr) + hostMap := newHostMap(l) hostMap.preferredRanges.Store(&preferredRanges) cs := &CertState{ - RawCertificate: []byte{}, - PrivateKey: []byte{}, - Certificate: &cert.NebulaCertificate{}, - RawCertificateNoKey: []byte{}, + defaultVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1}, + v1HandshakeBytes: []byte{}, } lh := newTestLighthouse() @@ -158,12 +155,12 @@ func Test_NewConnectionManagerTest2(t *testing.T) { // Add an ip we have established a connection w/ to hostmap hostinfo := &HostInfo{ - vpnIp: vpnIp, + vpnAddrs: []netip.Addr{vpnIp}, localIndexId: 1099, remoteIndexId: 9901, } hostinfo.ConnectionState = &ConnectionState{ - myCert: &cert.NebulaCertificate{}, + myCert: &dummyCert{version: cert.Version1}, H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -171,8 +168,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) { // We saw traffic out to vpnIp nc.Out(hostinfo.localIndexId) nc.In(hostinfo.localIndexId) - assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnIp) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0]) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded @@ -188,7 +185,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) // We saw traffic, should no longer be pending deletion nc.In(hostinfo.localIndexId) @@ -197,7 +194,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) } // Check if we can disconnect the peer. @@ -206,55 +203,48 @@ func Test_NewConnectionManagerTest2(t *testing.T) { func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { now := time.Now() l := test.NewLogger() - ipNet := net.IPNet{ - IP: net.IPv4(172, 1, 1, 2), - Mask: net.IPMask{255, 255, 255, 0}, - } + vpncidr := netip.MustParsePrefix("172.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24") vpnIp := netip.MustParseAddr("172.1.1.2") preferredRanges := []netip.Prefix{localrange} - hostMap := newHostMap(l, vpncidr) + hostMap := newHostMap(l) hostMap.preferredRanges.Store(&preferredRanges) // Generate keys for CA and peer's cert. pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader) - caCert := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "ca", - NotBefore: now, - NotAfter: now.Add(1 * time.Hour), - IsCA: true, - PublicKey: pubCA, - }, + tbs := &cert.TBSCertificate{ + Version: 1, + Name: "ca", + IsCA: true, + NotBefore: now, + NotAfter: now.Add(1 * time.Hour), + PublicKey: pubCA, } - assert.NoError(t, caCert.Sign(cert.Curve_CURVE25519, privCA)) - ncp := &cert.NebulaCAPool{ - CAs: cert.NewCAPool().CAs, - } - ncp.CAs["ca"] = &caCert + caCert, err := tbs.Sign(nil, cert.Curve_CURVE25519, privCA) + assert.NoError(t, err) + ncp := cert.NewCAPool() + assert.NoError(t, ncp.AddCA(caCert)) pubCrt, _, _ := ed25519.GenerateKey(rand.Reader) - peerCert := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host", - Ips: []*net.IPNet{&ipNet}, - Subnets: []*net.IPNet{}, - NotBefore: now, - NotAfter: now.Add(60 * time.Second), - PublicKey: pubCrt, - IsCA: false, - Issuer: "ca", - }, + tbs = &cert.TBSCertificate{ + Version: 1, + Name: "host", + Networks: []netip.Prefix{vpncidr}, + NotBefore: now, + NotAfter: now.Add(60 * time.Second), + PublicKey: pubCrt, } - assert.NoError(t, peerCert.Sign(cert.Curve_CURVE25519, privCA)) + peerCert, err := tbs.Sign(caCert, cert.Curve_CURVE25519, privCA) + assert.NoError(t, err) + + cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert) cs := &CertState{ - RawCertificate: []byte{}, - PrivateKey: []byte{}, - Certificate: &cert.NebulaCertificate{}, - RawCertificateNoKey: []byte{}, + privateKey: []byte{}, + v1Cert: &dummyCert{}, + v1HandshakeBytes: []byte{}, } lh := newTestLighthouse() @@ -280,10 +270,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { ifce.connectionManager = nc hostinfo := &HostInfo{ - vpnIp: vpnIp, + vpnAddrs: []netip.Addr{vpnIp}, ConnectionState: &ConnectionState{ - myCert: &cert.NebulaCertificate{}, - peerCert: &peerCert, + myCert: &dummyCert{}, + peerCert: cachedPeerCert, H: &noise.HandshakeState{}, }, } @@ -303,3 +293,114 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { invalid = nc.isInvalidCertificate(nextTick, hostinfo) assert.True(t, invalid) } + +type dummyCert struct { + version cert.Version + curve cert.Curve + groups []string + isCa bool + issuer string + name string + networks []netip.Prefix + notAfter time.Time + notBefore time.Time + publicKey []byte + signature []byte + unsafeNetworks []netip.Prefix +} + +func (d *dummyCert) Version() cert.Version { + return d.version +} + +func (d *dummyCert) Curve() cert.Curve { + return d.curve +} + +func (d *dummyCert) Groups() []string { + return d.groups +} + +func (d *dummyCert) IsCA() bool { + return d.isCa +} + +func (d *dummyCert) Issuer() string { + return d.issuer +} + +func (d *dummyCert) Name() string { + return d.name +} + +func (d *dummyCert) Networks() []netip.Prefix { + return d.networks +} + +func (d *dummyCert) NotAfter() time.Time { + return d.notAfter +} + +func (d *dummyCert) NotBefore() time.Time { + return d.notBefore +} + +func (d *dummyCert) PublicKey() []byte { + return d.publicKey +} + +func (d *dummyCert) Signature() []byte { + return d.signature +} + +func (d *dummyCert) UnsafeNetworks() []netip.Prefix { + return d.unsafeNetworks +} + +func (d *dummyCert) MarshalForHandshakes() ([]byte, error) { + return nil, nil +} + +func (d *dummyCert) Sign(curve cert.Curve, key []byte) error { + return nil +} + +func (d *dummyCert) CheckSignature(key []byte) bool { + return true +} + +func (d *dummyCert) Expired(t time.Time) bool { + return false +} + +func (d *dummyCert) CheckRootConstraints(signer cert.Certificate) error { + return nil +} + +func (d *dummyCert) VerifyPrivateKey(curve cert.Curve, key []byte) error { + return nil +} + +func (d *dummyCert) String() string { + return "" +} + +func (d *dummyCert) Marshal() ([]byte, error) { + return nil, nil +} + +func (d *dummyCert) MarshalPEM() ([]byte, error) { + return nil, nil +} + +func (d *dummyCert) Fingerprint() (string, error) { + return "", nil +} + +func (d *dummyCert) MarshalJSON() ([]byte, error) { + return nil, nil +} + +func (d *dummyCert) Copy() cert.Certificate { + return d +} diff --git a/connection_state.go b/connection_state.go index 1dd3c8c..faee443 100644 --- a/connection_state.go +++ b/connection_state.go @@ -3,6 +3,7 @@ package nebula import ( "crypto/rand" "encoding/json" + "fmt" "sync" "sync/atomic" @@ -18,50 +19,54 @@ type ConnectionState struct { eKey *NebulaCipherState dKey *NebulaCipherState H *noise.HandshakeState - myCert *cert.NebulaCertificate - peerCert *cert.NebulaCertificate + myCert cert.Certificate + peerCert *cert.CachedCertificate initiator bool messageCounter atomic.Uint64 window *Bits writeLock sync.Mutex } -func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { +func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) { var dhFunc noise.DHFunc - switch certState.Certificate.Details.Curve { + switch crt.Curve() { case cert.Curve_CURVE25519: dhFunc = noise.DH25519 case cert.Curve_P256: - dhFunc = noiseutil.DHP256 + if cs.pkcs11Backed { + dhFunc = noiseutil.DHP256PKCS11 + } else { + dhFunc = noiseutil.DHP256 + } default: - l.Errorf("invalid curve: %s", certState.Certificate.Details.Curve) - return nil + return nil, fmt.Errorf("invalid curve: %s", crt.Curve()) } - var cs noise.CipherSuite - if cipher == "chachapoly" { - cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256) + var ncs noise.CipherSuite + if cs.cipher == "chachapoly" { + ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256) } else { - cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256) + ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256) } - static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey} + static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()} b := NewBits(ReplayWindow) - // Clear out bit 0, we never transmit it and we don't want it showing as packet loss + // Clear out bit 0, we never transmit it, and we don't want it showing as packet loss b.Update(l, 0) hs, err := noise.NewHandshakeState(noise.Config{ - CipherSuite: cs, - Random: rand.Reader, - Pattern: pattern, - Initiator: initiator, - StaticKeypair: static, - PresharedKey: psk, - PresharedKeyPlacement: pskStage, + CipherSuite: ncs, + Random: rand.Reader, + Pattern: pattern, + Initiator: initiator, + StaticKeypair: static, + //NOTE: These should come from CertState (pki.go) when we finally implement it + PresharedKey: []byte{}, + PresharedKeyPlacement: 0, }) if err != nil { - return nil + return nil, fmt.Errorf("NewConnectionState: %s", err) } // The queue and ready params prevent a counter race that would happen when @@ -70,12 +75,12 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i H: hs, initiator: initiator, window: b, - myCert: certState.Certificate, + myCert: crt, } // always start the counter from 2, as packet 1 and packet 2 are handshake packets. ci.messageCounter.Add(2) - return ci + return ci, nil } func (cs *ConnectionState) MarshalJSON() ([]byte, error) { @@ -85,3 +90,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) { "message_counter": cs.messageCounter.Load(), }) } + +func (cs *ConnectionState) Curve() cert.Curve { + return cs.myCert.Curve() +} diff --git a/control.go b/control.go index 3468b35..20dd7fe 100644 --- a/control.go +++ b/control.go @@ -19,9 +19,9 @@ import ( type controlEach func(h *HostInfo) type controlHostLister interface { - QueryVpnIp(vpnIp netip.Addr) *HostInfo + QueryVpnAddr(vpnAddr netip.Addr) *HostInfo ForEachIndex(each controlEach) - ForEachVpnIp(each controlEach) + ForEachVpnAddr(each controlEach) GetPreferredRanges() []netip.Prefix } @@ -37,15 +37,15 @@ type Control struct { } type ControlHostInfo struct { - VpnIp netip.Addr `json:"vpnIp"` - LocalIndex uint32 `json:"localIndex"` - RemoteIndex uint32 `json:"remoteIndex"` - RemoteAddrs []netip.AddrPort `json:"remoteAddrs"` - Cert *cert.NebulaCertificate `json:"cert"` - MessageCounter uint64 `json:"messageCounter"` - CurrentRemote netip.AddrPort `json:"currentRemote"` - CurrentRelaysToMe []netip.Addr `json:"currentRelaysToMe"` - CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"` + VpnAddrs []netip.Addr `json:"vpnAddrs"` + LocalIndex uint32 `json:"localIndex"` + RemoteIndex uint32 `json:"remoteIndex"` + RemoteAddrs []netip.AddrPort `json:"remoteAddrs"` + Cert cert.Certificate `json:"cert"` + MessageCounter uint64 `json:"messageCounter"` + CurrentRemote netip.AddrPort `json:"currentRemote"` + CurrentRelaysToMe []netip.Addr `json:"currentRelaysToMe"` + CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"` } // Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock() @@ -130,15 +130,18 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo { } // GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found -func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) *cert.NebulaCertificate { - if c.f.myVpnNet.Addr() == vpnIp { - return c.f.pki.GetCertState().Certificate +func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate { + _, found := c.f.myVpnAddrsTable.Lookup(vpnIp) + if found { + // Only returning the default certificate since its impossible + // for any other host but ourselves to have more than 1 + return c.f.pki.getCertState().GetDefaultCertificate().Copy() } - hi := c.f.hostMap.QueryVpnIp(vpnIp) + hi := c.f.hostMap.QueryVpnAddr(vpnIp) if hi == nil { return nil } - return hi.GetCert() + return hi.GetCert().Certificate.Copy() } // CreateTunnel creates a new tunnel to the given vpn ip. @@ -148,7 +151,7 @@ func (c *Control) CreateTunnel(vpnIp netip.Addr) { // PrintTunnel creates a new tunnel to the given vpn ip. func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo { - hi := c.f.hostMap.QueryVpnIp(vpnIp) + hi := c.f.hostMap.QueryVpnAddr(vpnIp) if hi == nil { return nil } @@ -165,9 +168,9 @@ func (c *Control) QueryLighthouse(vpnIp netip.Addr) *CacheMap { return hi.CopyCache() } -// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found +// GetHostInfoByVpnAddr returns a single tunnels hostInfo, or nil if not found // Caller should take care to Unmap() any 4in6 addresses prior to calling. -func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHostInfo { +func (c *Control) GetHostInfoByVpnAddr(vpnAddr netip.Addr, pending bool) *ControlHostInfo { var hl controlHostLister if pending { hl = c.f.handshakeManager @@ -175,7 +178,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos hl = c.f.hostMap } - h := hl.QueryVpnIp(vpnIp) + h := hl.QueryVpnAddr(vpnAddr) if h == nil { return nil } @@ -187,7 +190,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos // SetRemoteForTunnel forces a tunnel to use a specific remote // Caller should take care to Unmap() any 4in6 addresses prior to calling. func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo { - hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) + hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp) if hostInfo == nil { return nil } @@ -200,7 +203,7 @@ func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *Con // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well. // Caller should take care to Unmap() any 4in6 addresses prior to calling. func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool { - hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) + hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp) if hostInfo == nil { return false } @@ -224,19 +227,14 @@ func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool { // CloseAllTunnels is just like CloseTunnel except it goes through and shuts them all down, optionally you can avoid shutting down lighthouse tunnels // the int returned is a count of tunnels closed func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { - //TODO: this is probably better as a function in ConnectionManager or HostMap directly - lighthouses := c.f.lightHouse.GetLighthouses() - shutdown := func(h *HostInfo) { - if excludeLighthouses { - if _, ok := lighthouses[h.vpnIp]; ok { - return - } + if excludeLighthouses && c.f.lightHouse.IsAnyLighthouseAddr(h.vpnAddrs) { + return } c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) c.f.closeTunnel(h) - c.l.WithField("vpnIp", h.vpnIp).WithField("udpAddr", h.remote). + c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote). Debug("Sending close tunnel message") closed++ } @@ -246,7 +244,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { // Grab the hostMap lock to access the Relays map c.f.hostMap.Lock() for _, relayingHost := range c.f.hostMap.Relays { - relayingHosts[relayingHost.vpnIp] = relayingHost + relayingHosts[relayingHost.vpnAddrs[0]] = relayingHost } c.f.hostMap.Unlock() @@ -254,7 +252,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { // Grab the hostMap lock to access the Hosts map c.f.hostMap.Lock() for _, relayHost := range c.f.hostMap.Indexes { - if _, ok := relayingHosts[relayHost.vpnIp]; !ok { + if _, ok := relayingHosts[relayHost.vpnAddrs[0]]; !ok { hostInfos = append(hostInfos, relayHost) } } @@ -274,9 +272,8 @@ func (c *Control) Device() overlay.Device { } func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo { - chi := ControlHostInfo{ - VpnIp: h.vpnIp, + VpnAddrs: make([]netip.Addr, len(h.vpnAddrs)), LocalIndex: h.localIndexId, RemoteIndex: h.remoteIndexId, RemoteAddrs: h.remotes.CopyAddrs(preferredRanges), @@ -285,12 +282,16 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo { CurrentRemote: h.remote, } + for i, a := range h.vpnAddrs { + chi.VpnAddrs[i] = a + } + if h.ConnectionState != nil { chi.MessageCounter = h.ConnectionState.messageCounter.Load() } if c := h.GetCert(); c != nil { - chi.Cert = c.Copy() + chi.Cert = c.Certificate.Copy() } return chi @@ -299,7 +300,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo { func listHostMapHosts(hl controlHostLister) []ControlHostInfo { hosts := make([]ControlHostInfo, 0) pr := hl.GetPreferredRanges() - hl.ForEachVpnIp(func(hostinfo *HostInfo) { + hl.ForEachVpnAddr(func(hostinfo *HostInfo) { hosts = append(hosts, copyHostInfo(hostinfo, pr)) }) return hosts diff --git a/control_test.go b/control_test.go index fbf29c0..6ce7083 100644 --- a/control_test.go +++ b/control_test.go @@ -5,7 +5,6 @@ import ( "net/netip" "reflect" "testing" - "time" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" @@ -14,10 +13,13 @@ import ( ) func TestControl_GetHostInfoByVpnIp(t *testing.T) { + //TODO: CERT-V2 with multiple certificate versions we have a problem with this test + // Some certs versions have different characteristics and each version implements their own Copy() func + // which means this is not a good place to test for exposing memory l := test.NewLogger() // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object // To properly ensure we are not exposing core memory to the caller - hm := newHostMap(l, netip.Prefix{}) + hm := newHostMap(l) hm.preferredRanges.Store(&[]netip.Prefix{}) remote1 := netip.MustParseAddrPort("0.0.0.100:4444") @@ -33,42 +35,27 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { Mask: net.IPMask{255, 255, 255, 0}, } - crt := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "test", - Ips: []*net.IPNet{&ipNet}, - Subnets: []*net.IPNet{}, - Groups: []string{"default-group"}, - NotBefore: time.Unix(1, 0), - NotAfter: time.Unix(2, 0), - PublicKey: []byte{5, 6, 7, 8}, - IsCA: false, - Issuer: "the-issuer", - InvertedGroups: map[string]struct{}{"default-group": {}}, - }, - Signature: []byte{1, 2, 1, 2, 1, 3}, - } - - remotes := NewRemoteList(nil) - remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port())) - remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port())) + remotes := NewRemoteList([]netip.Addr{netip.IPv4Unspecified()}, nil) + remotes.unlockedPrependV4(netip.IPv4Unspecified(), netAddrToProtoV4AddrPort(remote1.Addr(), remote1.Port())) + remotes.unlockedPrependV6(netip.IPv4Unspecified(), netAddrToProtoV6AddrPort(remote2.Addr(), remote2.Port())) vpnIp, ok := netip.AddrFromSlice(ipNet.IP) assert.True(t, ok) + crt := &dummyCert{} hm.unlockedAddHostInfo(&HostInfo{ remote: remote1, remotes: remotes, ConnectionState: &ConnectionState{ - peerCert: crt, + peerCert: &cert.CachedCertificate{Certificate: crt}, }, remoteIndexId: 200, localIndexId: 201, - vpnIp: vpnIp, + vpnAddrs: []netip.Addr{vpnIp}, relayState: RelayState{ - relays: map[netip.Addr]struct{}{}, - relayForByIp: map[netip.Addr]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, }, }, &Interface{}) @@ -83,11 +70,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { }, remoteIndexId: 200, localIndexId: 201, - vpnIp: vpnIp2, + vpnAddrs: []netip.Addr{vpnIp2}, relayState: RelayState{ - relays: map[netip.Addr]struct{}{}, - relayForByIp: map[netip.Addr]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, }, }, &Interface{}) @@ -98,10 +85,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { l: logrus.New(), } - thi := c.GetHostInfoByVpnIp(vpnIp, false) + thi := c.GetHostInfoByVpnAddr(vpnIp, false) expectedInfo := ControlHostInfo{ - VpnIp: vpnIp, + VpnAddrs: []netip.Addr{vpnIp}, LocalIndex: 201, RemoteIndex: 200, RemoteAddrs: []netip.AddrPort{remote2, remote1}, @@ -113,14 +100,13 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { } // Make sure we don't have any unexpected fields - assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) + assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) assert.EqualValues(t, &expectedInfo, thi) - //TODO: netip.Addr reuses global memory for zone identifiers which breaks our "no reused memory check" here - //test.AssertDeepCopyEqual(t, &expectedInfo, thi) + test.AssertDeepCopyEqual(t, &expectedInfo, thi) // Make sure we don't panic if the host info doesn't have a cert yet assert.NotPanics(t, func() { - thi = c.GetHostInfoByVpnIp(vpnIp2, false) + thi = c.GetHostInfoByVpnAddr(vpnIp2, false) }) } diff --git a/control_tester.go b/control_tester.go index d46540f..451dac5 100644 --- a/control_tester.go +++ b/control_tester.go @@ -6,8 +6,6 @@ package nebula import ( "net/netip" - "github.com/slackhq/nebula/cert" - "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/slackhq/nebula/header" @@ -51,15 +49,15 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType, // This is necessary if you did not configure static hosts or are not running a lighthouse func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) { c.f.lightHouse.Lock() - remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp) + remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp}) remoteList.Lock() defer remoteList.Unlock() c.f.lightHouse.Unlock() if toAddr.Addr().Is4() { - remoteList.unlockedPrependV4(vpnIp, NewIp4AndPortFromNetIP(toAddr.Addr(), toAddr.Port())) + remoteList.unlockedPrependV4(vpnIp, netAddrToProtoV4AddrPort(toAddr.Addr(), toAddr.Port())) } else { - remoteList.unlockedPrependV6(vpnIp, NewIp6AndPortFromNetIP(toAddr.Addr(), toAddr.Port())) + remoteList.unlockedPrependV6(vpnIp, netAddrToProtoV6AddrPort(toAddr.Addr(), toAddr.Port())) } } @@ -67,12 +65,12 @@ func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) // This is necessary to inform an initiator of possible relays for communicating with a responder func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) { c.f.lightHouse.Lock() - remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp) + remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp}) remoteList.Lock() defer remoteList.Unlock() c.f.lightHouse.Unlock() - remoteList.unlockedSetRelay(vpnIp, vpnIp, relayVpnIps) + remoteList.unlockedSetRelay(vpnIp, relayVpnIps) } // GetFromTun will pull a packet off the tun side of nebula @@ -99,21 +97,42 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) { } // InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol -func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort uint16, data []byte) { - //TODO: IPV6-WORK - ip := layers.IPv4{ - Version: 4, - TTL: 64, - Protocol: layers.IPProtocolUDP, - SrcIP: c.f.inside.Cidr().Addr().Unmap().AsSlice(), - DstIP: toIp.Unmap().AsSlice(), +func (c *Control) InjectTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) { + serialize := make([]gopacket.SerializableLayer, 0) + var netLayer gopacket.NetworkLayer + if toAddr.Is6() { + if !fromAddr.Is6() { + panic("Cant send ipv6 to ipv4") + } + ip := &layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolUDP, + SrcIP: fromAddr.Unmap().AsSlice(), + DstIP: toAddr.Unmap().AsSlice(), + } + serialize = append(serialize, ip) + netLayer = ip + } else { + if !fromAddr.Is4() { + panic("Cant send ipv4 to ipv6") + } + + ip := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + SrcIP: fromAddr.Unmap().AsSlice(), + DstIP: toAddr.Unmap().AsSlice(), + } + serialize = append(serialize, ip) + netLayer = ip } udp := layers.UDP{ SrcPort: layers.UDPPort(fromPort), DstPort: layers.UDPPort(toPort), } - err := udp.SetNetworkLayerForChecksum(&ip) + err := udp.SetNetworkLayerForChecksum(netLayer) if err != nil { panic(err) } @@ -123,7 +142,9 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui ComputeChecksums: true, FixLengths: true, } - err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data)) + + serialize = append(serialize, &udp, gopacket.Payload(data)) + err = gopacket.SerializeLayers(buffer, opt, serialize...) if err != nil { panic(err) } @@ -131,8 +152,8 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui c.f.inside.(*overlay.TestTun).Send(buffer.Bytes()) } -func (c *Control) GetVpnIp() netip.Addr { - return c.f.myVpnNet.Addr() +func (c *Control) GetVpnAddrs() []netip.Addr { + return c.f.myVpnAddrs } func (c *Control) GetUDPAddr() netip.AddrPort { @@ -140,7 +161,7 @@ func (c *Control) GetUDPAddr() netip.AddrPort { } func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool { - hostinfo := c.f.handshakeManager.QueryVpnIp(vpnIp) + hostinfo := c.f.handshakeManager.QueryVpnAddr(vpnIp) if hostinfo == nil { return false } @@ -153,8 +174,8 @@ func (c *Control) GetHostmap() *HostMap { return c.f.hostMap } -func (c *Control) GetCert() *cert.NebulaCertificate { - return c.f.pki.GetCertState().Certificate +func (c *Control) GetCertState() *CertState { + return c.f.pki.getCertState() } func (c *Control) ReHandshake(vpnIp netip.Addr) { diff --git a/dns_server.go b/dns_server.go index 5fea65c..710f6ed 100644 --- a/dns_server.go +++ b/dns_server.go @@ -8,6 +8,7 @@ import ( "strings" "sync" + "github.com/gaissmai/bart" "github.com/miekg/dns" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" @@ -21,24 +22,39 @@ var dnsAddr string type dnsRecords struct { sync.RWMutex - dnsMap map[string]string - hostMap *HostMap + l *logrus.Logger + dnsMap4 map[string]netip.Addr + dnsMap6 map[string]netip.Addr + hostMap *HostMap + myVpnAddrsTable *bart.Table[struct{}] } -func newDnsRecords(hostMap *HostMap) *dnsRecords { +func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords { return &dnsRecords{ - dnsMap: make(map[string]string), - hostMap: hostMap, + l: l, + dnsMap4: make(map[string]netip.Addr), + dnsMap6: make(map[string]netip.Addr), + hostMap: hostMap, + myVpnAddrsTable: cs.myVpnAddrsTable, } } -func (d *dnsRecords) Query(data string) string { +func (d *dnsRecords) Query(q uint16, data string) netip.Addr { + data = strings.ToLower(data) d.RLock() defer d.RUnlock() - if r, ok := d.dnsMap[strings.ToLower(data)]; ok { - return r + switch q { + case dns.TypeA: + if r, ok := d.dnsMap4[data]; ok { + return r + } + case dns.TypeAAAA: + if r, ok := d.dnsMap6[data]; ok { + return r + } } - return "" + + return netip.Addr{} } func (d *dnsRecords) QueryCert(data string) string { @@ -47,7 +63,7 @@ func (d *dnsRecords) QueryCert(data string) string { return "" } - hostinfo := d.hostMap.QueryVpnIp(ip) + hostinfo := d.hostMap.QueryVpnAddr(ip) if hostinfo == nil { return "" } @@ -57,43 +73,69 @@ func (d *dnsRecords) QueryCert(data string) string { return "" } - cert := q.Details - c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer) - return c + b, err := q.Certificate.MarshalJSON() + if err != nil { + return "" + } + return string(b) } -func (d *dnsRecords) Add(host, data string) { +// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host` +func (d *dnsRecords) Add(host string, addresses []netip.Addr) { + host = strings.ToLower(host) d.Lock() defer d.Unlock() - d.dnsMap[strings.ToLower(host)] = data + haveV4 := false + haveV6 := false + for _, addr := range addresses { + if addr.Is4() && !haveV4 { + d.dnsMap4[host] = addr + haveV4 = true + } else if addr.Is6() && !haveV6 { + d.dnsMap6[host] = addr + haveV6 = true + } + if haveV4 && haveV6 { + break + } + } } -func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) { +func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool { + a, _, _ := net.SplitHostPort(addr) + b, err := netip.ParseAddr(a) + if err != nil { + return false + } + + if b.IsLoopback() { + return true + } + + _, found := d.myVpnAddrsTable.Lookup(b) + return found //if we found it in this table, it's good +} + +func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) { for _, q := range m.Question { switch q.Qtype { - case dns.TypeA: - l.Debugf("Query for A %s", q.Name) - ip := dnsR.Query(q.Name) - if ip != "" { - rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip)) + case dns.TypeA, dns.TypeAAAA: + qType := dns.TypeToString[q.Qtype] + d.l.Debugf("Query for %s %s", qType, q.Name) + ip := d.Query(q.Qtype, q.Name) + if ip.IsValid() { + rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip)) if err == nil { m.Answer = append(m.Answer, rr) } } case dns.TypeTXT: - a, _, _ := net.SplitHostPort(w.RemoteAddr().String()) - b, err := netip.ParseAddr(a) - if err != nil { + // We only answer these queries from nebula nodes or localhost + if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) { return } - - // We don't answer these queries from non nebula nodes or localhost - //l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR) - if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" { - return - } - l.Debugf("Query for TXT %s", q.Name) - ip := dnsR.QueryCert(q.Name) + d.l.Debugf("Query for TXT %s", q.Name) + ip := d.QueryCert(q.Name) if ip != "" { rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip)) if err == nil { @@ -108,26 +150,24 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) { } } -func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) { +func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { m := new(dns.Msg) m.SetReply(r) m.Compress = false switch r.Opcode { case dns.OpcodeQuery: - parseQuery(l, m, w) + d.parseQuery(m, w) } w.WriteMsg(m) } -func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() { - dnsR = newDnsRecords(hostMap) +func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() { + dnsR = newDnsRecords(l, cs, hostMap) // attach request handler func - dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) { - handleDnsRequest(l, w, r) - }) + dns.HandleFunc(".", dnsR.handleDnsRequest) c.RegisterReloadCallback(func(c *config.C) { reloadDns(l, c) diff --git a/dns_server_test.go b/dns_server_test.go index 69f6ae8..f4643a3 100644 --- a/dns_server_test.go +++ b/dns_server_test.go @@ -1,23 +1,38 @@ package nebula import ( + "net/netip" "testing" "github.com/miekg/dns" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/stretchr/testify/assert" ) func TestParsequery(t *testing.T) { - //TODO: This test is basically pointless + l := logrus.New() hostMap := &HostMap{} - ds := newDnsRecords(hostMap) - ds.Add("test.com.com", "1.2.3.4") + ds := newDnsRecords(l, &CertState{}, hostMap) + addrs := []netip.Addr{ + netip.MustParseAddr("1.2.3.4"), + netip.MustParseAddr("1.2.3.5"), + netip.MustParseAddr("fd01::24"), + netip.MustParseAddr("fd01::25"), + } + ds.Add("test.com.com", addrs) - m := new(dns.Msg) + m := &dns.Msg{} m.SetQuestion("test.com.com", dns.TypeA) + ds.parseQuery(m, nil) + assert.NotNil(t, m.Answer) + assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String()) - //parseQuery(m) + m = &dns.Msg{} + m.SetQuestion("test.com.com", dns.TypeAAAA) + ds.parseQuery(m, nil) + assert.NotNil(t, m.Answer) + assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String()) } func Test_getDnsServerAddr(t *testing.T) { diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 3d42a56..2e7e6e4 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -4,13 +4,17 @@ package e2e import ( - "fmt" "net/netip" + "slices" "testing" "time" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" "github.com/sirupsen/logrus" "github.com/slackhq/nebula" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" @@ -19,12 +23,12 @@ import ( ) func BenchmarkHotPath(b *testing.B) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, _, _, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Start the servers myControl.Start() @@ -34,7 +38,7 @@ func BenchmarkHotPath(b *testing.B) { r.CancelFlowLogs() for n := 0; n < b.N; n++ { - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) _ = r.RouteForAllUntilTxTun(theirControl) } @@ -43,19 +47,19 @@ func BenchmarkHotPath(b *testing.B) { } func TestGoodHandshake(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Start the servers myControl.Start() theirControl.Start() t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) t.Log("Have them consume my stage 0 packet. They have a tunnel now") theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) @@ -76,38 +80,31 @@ func TestGoodHandshake(t *testing.T) { myControl.WaitForType(1, 0, theirControl) t.Log("Make sure our host infos are correct") - assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl) + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl) t.Log("Get that cached packet and make sure it looks right") myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) t.Log("Do a bidirectional tunnel test") r := router.NewR(t, myControl, theirControl) defer r.RenderFlow() - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.RenderHostmaps("Final hostmaps", myControl, theirControl) myControl.Stop() theirControl.Stop() - //TODO: assert hostmaps } func TestWrongResponderHandshake(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - // The IPs here are chosen on purpose: - // The current remote handling will sort by preference, public, and then lexically. - // So we need them to have a higher address than evil (we could apply a preference though) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil) - evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil) - - // Add their real udp addr, which should be tried after evil. - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil) + evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil) // Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse. - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl, evilControl) @@ -118,10 +115,30 @@ func TestWrongResponderHandshake(t *testing.T) { theirControl.Start() evilControl.Start() - t.Log("Start the handshake process, we will route until we see our cached packet get sent to them") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + t.Log("Start the handshake process, we will route until we see the evil tunnel closed") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + + h := &header.H{} + r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { + err := h.Parse(p.Data) + if err != nil { + panic(err) + } + + if h.Type == header.CloseTunnel && p.To == evilUdpAddr { + return router.RouteAndExit + } + + return router.KeepRouting + }) + + t.Log("Evil tunnel is closed, inject the correct udp addr for them") + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + pendingHi := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), true) + assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr) + + t.Log("Route until we see the cached packet") r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { - h := &header.H{} err := h.Parse(p.Data) if err != nil { panic(err) @@ -134,25 +151,103 @@ func TestWrongResponderHandshake(t *testing.T) { return router.KeepRouting }) - //TODO: Assert pending hostmap - I should have a correct hostinfo for them now - t.Log("My cached packet should be received by them") myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) t.Log("Test the tunnel with them") - assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl) - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Flush all packets from all controllers") r.FlushAll() t.Log("Ensure ensure I don't have any hostinfo artifacts from evil") - assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil") - assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), true), "My pending hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), false), "My main hostmap should not contain evil") + + r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl) + t.Log("Success!") + myControl.Stop() + theirControl.Stop() +} + +func TestWrongResponderHandshakeStaticHostMap(t *testing.T) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil) + evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil) + o := m{ + "static_host_map": m{ + theirVpnIpNet[0].Addr().String(): []string{evilUdpAddr.String()}, + }, + } + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", o) + + // Put the evil udp addr in for their vpn addr, this is a case of a remote at a static entry changing its vpn addr. + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, theirControl, evilControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + theirControl.Start() + evilControl.Start() + + t.Log("Start the handshake process, we will route until we see the evil tunnel closed") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + + h := &header.H{} + r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { + err := h.Parse(p.Data) + if err != nil { + panic(err) + } + + if h.Type == header.CloseTunnel && p.To == evilUdpAddr { + return router.RouteAndExit + } + + return router.KeepRouting + }) + + t.Log("Evil tunnel is closed, inject the correct udp addr for them") + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + pendingHi := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), true) + assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr) + + t.Log("Route until we see the cached packet") + r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { + err := h.Parse(p.Data) + if err != nil { + panic(err) + } + + if p.To == theirUdpAddr && h.Type == 1 { + return router.RouteAndExit + } + + return router.KeepRouting + }) + + t.Log("My cached packet should be received by them") + myCachedPacket := theirControl.GetFromTun(true) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) + + t.Log("Test the tunnel with them") + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + t.Log("Flush all packets from all controllers") + r.FlushAll() + + t.Log("Ensure ensure I don't have any hostinfo artifacts from evil") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), true), "My pending hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), false), "My main hostmap should not contain evil") //NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete - //TODO: assert hostmaps for everyone r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl) t.Log("Success!") myControl.Stop() @@ -163,13 +258,13 @@ func TestStage1Race(t *testing.T) { // This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow // But will eventually collapse down to a single tunnel - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -180,8 +275,8 @@ func TestStage1Race(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake to start on both me and them") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) t.Log("Get both stage 1 handshake packets") myHsForThem := myControl.GetFromUDP(true) @@ -193,14 +288,14 @@ func TestStage1Race(t *testing.T) { r.Log("Route until they receive a message packet") myCachedPacket := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.Log("Their cached packet should be received by me") theirCachedPacket := r.RouteForAllUntilTxTun(myControl) - assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80) r.Log("Do a bidirectional tunnel test") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) myHostmapHosts := myControl.ListHostmapHosts(false) myHostmapIndexes := myControl.ListHostmapIndexes(false) @@ -218,7 +313,7 @@ func TestStage1Race(t *testing.T) { r.Log("Spin until connection manager tears down a tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } @@ -240,13 +335,13 @@ func TestStage1Race(t *testing.T) { } func TestUncleanShutdownRaceLoser(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -257,10 +352,10 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.Log("Nuke my hostmap") myHostmap := myControl.GetHostmap() @@ -268,17 +363,17 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { myHostmap.Indexes = map[uint32]*nebula.HostInfo{} myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me again")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me again")) p = r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.Log("Assert the tunnel works") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.Log("Wait for the dead index to go away") start := len(theirControl.GetHostmap().Indexes) for { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) if len(theirControl.GetHostmap().Indexes) < start { break } @@ -289,13 +384,13 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { } func TestUncleanShutdownRaceWinner(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -306,10 +401,10 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("Final hostmaps", myControl, theirControl) r.Log("Nuke my hostmap") @@ -318,18 +413,18 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { theirHostmap.Indexes = map[uint32]*nebula.HostInfo{} theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them again")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them again")) p = r.RouteForAllUntilTxTun(myControl) - assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("Derp hostmaps", myControl, theirControl) r.Log("Assert the tunnel works") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.Log("Wait for the dead index to go away") start := len(myControl.GetHostmap().Indexes) for { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) if len(myControl.GetHostmap().Indexes) < start { break } @@ -340,15 +435,15 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { } func TestRelays(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -360,31 +455,161 @@ func TestRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) - //TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it +} + +func TestReestablishRelays(t *testing.T) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + + // Teach my how to get to the relay and that their can be reached via the relay + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, relayControl, theirControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + relayControl.Start() + theirControl.Start() + + t.Log("Trigger a handshake from me to them via the relay") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + + p := r.RouteForAllUntilTxTun(theirControl) + r.Log("Assert the tunnel works") + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) + + t.Log("Ensure packet traversal from them to me via the relay") + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + + p = r.RouteForAllUntilTxTun(myControl) + r.Log("Assert the tunnel works") + assertUdpPacket(t, []byte("Hi from them"), p, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80) + + // If we break the relay's connection to 'them', 'me' needs to detect and recover the connection + r.Log("Close the tunnel") + relayControl.CloseTunnel(theirVpnIpNet[0].Addr(), true) + + start := len(myControl.GetHostmap().Indexes) + curIndexes := len(myControl.GetHostmap().Indexes) + for curIndexes >= start { + curIndexes = len(myControl.GetHostmap().Indexes) + r.Logf("Wait for the dead index to go away:start=%v indexes, currnet=%v indexes", start, curIndexes) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail")) + + r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { + return router.RouteAndExit + }) + time.Sleep(2 * time.Second) + } + r.Log("Dead index went away. Woot!") + r.RenderHostmaps("Me removed hostinfo", myControl, relayControl, theirControl) + // Next packet should re-establish a relayed connection and work just great. + + t.Logf("Assert the tunnel...") + for { + t.Log("RouteForAllUntilTxTun") + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + + p = r.RouteForAllUntilTxTun(theirControl) + r.Log("Assert the tunnel works") + packet := gopacket.NewPacket(p, layers.LayerTypeIPv4, gopacket.Lazy) + v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + if slices.Compare(v4.SrcIP, myVpnIpNet[0].Addr().AsSlice()) != 0 { + t.Logf("SrcIP is unexpected...this is not the packet I'm looking for. Keep looking") + continue + } + if slices.Compare(v4.DstIP, theirVpnIpNet[0].Addr().AsSlice()) != 0 { + t.Logf("DstIP is unexpected...this is not the packet I'm looking for. Keep looking") + continue + } + + udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) + if udp == nil { + t.Log("Not a UDP packet. This is not the packet I'm looking for. Keep looking") + continue + } + data := packet.ApplicationLayer() + if data == nil { + t.Log("No data found in packet. This is not the packet I'm looking for. Keep looking.") + continue + } + if string(data.Payload()) != "Hi from me" { + t.Logf("Unexpected payload: '%v', keep looking", string(data.Payload())) + continue + } + t.Log("I found my lost packet. I am so happy.") + break + } + t.Log("Assert the tunnel works the other way, too") + for { + t.Log("RouteForAllUntilTxTun") + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + + p = r.RouteForAllUntilTxTun(myControl) + r.Log("Assert the tunnel works") + packet := gopacket.NewPacket(p, layers.LayerTypeIPv4, gopacket.Lazy) + v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + if slices.Compare(v4.DstIP, myVpnIpNet[0].Addr().AsSlice()) != 0 { + t.Logf("Dst is unexpected...this is not the packet I'm looking for. Keep looking") + continue + } + if slices.Compare(v4.SrcIP, theirVpnIpNet[0].Addr().AsSlice()) != 0 { + t.Logf("SrcIP is unexpected...this is not the packet I'm looking for. Keep looking") + continue + } + + udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) + if udp == nil { + t.Log("Not a UDP packet. This is not the packet I'm looking for. Keep looking") + continue + } + data := packet.ApplicationLayer() + if data == nil { + t.Log("No data found in packet. This is not the packet I'm looking for. Keep looking.") + continue + } + if string(data.Payload()) != "Hi from them" { + t.Logf("Unexpected payload: '%v', keep looking", string(data.Payload())) + continue + } + t.Log("I found my lost packet. I am so happy.") + break + } + r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) + } func TestStage1RaceRelays(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + theirControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + theirControl.InjectRelays(myVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + relayControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -396,14 +621,14 @@ func TestStage1RaceRelays(t *testing.T) { theirControl.Start() r.Log("Get a tunnel between me and relay") - assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) r.Log("Get a tunnel between them and relay") - assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) r.Log("Wait for a packet from them to me") p := r.RouteForAllUntilTxTun(myControl) @@ -414,27 +639,25 @@ func TestStage1RaceRelays(t *testing.T) { myControl.Stop() theirControl.Stop() relayControl.Stop() - // - ////TODO: assert hostmaps } func TestStage1RaceRelays2(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) l := NewTestLogger() // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + theirControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + theirControl.InjectRelays(myVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + relayControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -447,16 +670,16 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Get a tunnel between me and relay") l.Info("Get a tunnel between me and relay") - assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) r.Log("Get a tunnel between them and relay") l.Info("Get a tunnel between them and relay") - assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") l.Info("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) //r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone) //r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone) @@ -469,7 +692,7 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Assert the tunnel works") l.Info("Assert the tunnel works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) t.Log("Wait until we remove extra tunnels") l.Info("Wait until we remove extra tunnels") @@ -489,7 +712,7 @@ func TestStage1RaceRelays2(t *testing.T) { "theirControl": len(theirControl.GetHostmap().Indexes), "relayControl": len(relayControl.GetHostmap().Indexes), }).Info("Waiting for hostinfos to be removed...") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) retries-- @@ -497,26 +720,23 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Assert the tunnel works") l.Info("Assert the tunnel works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) myControl.Stop() theirControl.Stop() relayControl.Stop() - - // - ////TODO: assert hostmaps } func TestRehandshakingRelays(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -528,19 +748,19 @@ func TestRehandshakingRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. r.Log("Renew relay certificate and spin until me and them sees it") - _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) - caB, err := ca.MarshalToPEM() + caB, err := ca.MarshalPEM() if err != nil { panic(err) } @@ -556,9 +776,9 @@ func TestRehandshakingRelays(t *testing.T) { for { r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") - assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) - c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) - if len(c.Cert.Details.Groups) != 0 { + assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) + c := myControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false) + if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between my and relay is updated!") break @@ -569,9 +789,9 @@ func TestRehandshakingRelays(t *testing.T) { for { r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") - assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) - c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) - if len(c.Cert.Details.Groups) != 0 { + assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) + c := theirControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false) + if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between their and relay is updated!") break @@ -581,13 +801,13 @@ func TestRehandshakingRelays(t *testing.T) { } r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // We should have two hostinfos on all sides for len(myControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -595,7 +815,7 @@ func TestRehandshakingRelays(t *testing.T) { for len(theirControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -603,7 +823,7 @@ func TestRehandshakingRelays(t *testing.T) { for len(relayControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -612,15 +832,15 @@ func TestRehandshakingRelays(t *testing.T) { func TestRehandshakingRelaysPrimary(t *testing.T) { // This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.1/24", m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.1/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -632,19 +852,19 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. r.Log("Renew relay certificate and spin until me and them sees it") - _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) - caB, err := ca.MarshalToPEM() + caB, err := ca.MarshalPEM() if err != nil { panic(err) } @@ -660,9 +880,9 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for { r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") - assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) - c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) - if len(c.Cert.Details.Groups) != 0 { + assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) + c := myControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false) + if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between my and relay is updated!") break @@ -673,9 +893,9 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for { r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") - assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) - c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) - if len(c.Cert.Details.Groups) != 0 { + assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) + c := theirControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false) + if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between their and relay is updated!") break @@ -685,13 +905,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { } r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // We should have two hostinfos on all sides for len(myControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -699,7 +919,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for len(theirControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -707,7 +927,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for len(relayControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -715,13 +935,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { } func TestRehandshaking(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -732,14 +952,14 @@ func TestRehandshaking(t *testing.T) { theirControl.Start() t.Log("Stand up a tunnel between me and them") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew my certificate and spin until their sees it") - _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"}) - caB, err := ca.MarshalToPEM() + caB, err := ca.MarshalPEM() if err != nil { panic(err) } @@ -754,9 +974,9 @@ func TestRehandshaking(t *testing.T) { myConfig.ReloadConfigString(string(rc)) for { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) - c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) - if len(c.Cert.Details.Groups) != 0 { + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + if len(c.Cert.Groups()) != 0 { // We have a new certificate now break } @@ -764,6 +984,7 @@ func TestRehandshaking(t *testing.T) { time.Sleep(time.Second) } + r.Log("Got the new cert") // Flip their firewall to only allowing the new group to catch the tunnels reverting incorrectly rc, err = yaml.Marshal(theirConfig.Settings) assert.NoError(t, err) @@ -781,20 +1002,20 @@ func TestRehandshaking(t *testing.T) { r.Log("Spin until there is only 1 tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) myFinalHostmapHosts := myControl.ListHostmapHosts(false) myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) // Make sure the correct tunnel won - c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) - assert.Contains(t, c.Cert.Details.Groups, "new group") + c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + assert.Contains(t, c.Cert.Groups(), "new group") // We should only have a single tunnel now on both sides assert.Len(t, myFinalHostmapHosts, 1) @@ -811,13 +1032,13 @@ func TestRehandshaking(t *testing.T) { func TestRehandshakingLoser(t *testing.T) { // The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel // Should be the one with the new certificate - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -828,18 +1049,14 @@ func TestRehandshakingLoser(t *testing.T) { theirControl.Start() t.Log("Stand up a tunnel between me and them") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) - - tt1 := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) - tt2 := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) - fmt.Println(tt1.LocalIndex, tt2.LocalIndex) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew their certificate and spin until mine sees it") - _, _, theirNextPrivKey, theirNextPEM := NewTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"}) + _, _, theirNextPrivKey, theirNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"}) - caB, err := ca.MarshalToPEM() + caB, err := ca.MarshalPEM() if err != nil { panic(err) } @@ -854,11 +1071,10 @@ func TestRehandshakingLoser(t *testing.T) { theirConfig.ReloadConfigString(string(rc)) for { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) - theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + theirCertInMe := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) - _, theirNewGroup := theirCertInMe.Cert.Details.InvertedGroups["their new group"] - if theirNewGroup { + if slices.Contains(theirCertInMe.Cert.Groups(), "their new group") { break } @@ -882,20 +1098,20 @@ func TestRehandshakingLoser(t *testing.T) { r.Log("Spin until there is only 1 tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) myFinalHostmapHosts := myControl.ListHostmapHosts(false) myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) // Make sure the correct tunnel won - theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) - assert.Contains(t, theirCertInMe.Cert.Details.Groups, "their new group") + theirCertInMe := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) + assert.Contains(t, theirCertInMe.Cert.Groups(), "their new group") // We should only have a single tunnel now on both sides assert.Len(t, myFinalHostmapHosts, 1) @@ -912,13 +1128,13 @@ func TestRaceRegression(t *testing.T) { // This test forces stage 1, stage 2, stage 1 to be received by me from them // We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which // caused a cross-linked hostinfo - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Start the servers myControl.Start() @@ -932,8 +1148,8 @@ func TestRaceRegression(t *testing.T) { //them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089 t.Log("Start both handshakes") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) t.Log("Get both stage 1") myStage1ForThem := myControl.GetFromUDP(true) @@ -944,7 +1160,6 @@ func TestRaceRegression(t *testing.T) { myControl.InjectUDPPacket(theirStage1ForMe) theirControl.InjectUDPPacket(myStage1ForThem) - //TODO: ensure stage 2 t.Log("Get both stage 2") myStage2ForThem := myControl.GetFromUDP(true) theirStage2ForMe := theirControl.GetFromUDP(true) @@ -963,14 +1178,48 @@ func TestRaceRegression(t *testing.T) { r.RenderHostmaps("Starting hostmaps", myControl, theirControl) t.Log("Make sure the tunnel still works") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) myControl.Stop() theirControl.Stop() } -//TODO: test -// Race winner renews and handshakes -// Race loser renews and handshakes -// Does race winner repin the cert to old? -//TODO: add a test with many lies +func TestV2NonPrimaryWithLighthouse(t *testing.T) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "10.128.0.1/24, ff::1/64", m{"lighthouse": m{"am_lighthouse": true}}) + + o := m{ + "static_host_map": m{ + lhVpnIpNet[1].Addr().String(): []string{lhUdpAddr.String()}, + }, + "lighthouse": m{ + "hosts": []string{lhVpnIpNet[1].Addr().String()}, + "local_allow_list": m{ + // Try and block our lighthouse updates from using the actual addresses assigned to this computer + // If we start discovering addresses the test router doesn't know about then test traffic cant flow + "10.0.0.0/24": true, + "::/0": false, + }, + }, + } + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.2/24, ff::2/64", o) + theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, lhControl, myControl, theirControl) + defer r.RenderFlow() + + // Start the servers + lhControl.Start() + myControl.Start() + theirControl.Start() + + t.Log("Stand up an ipv6 tunnel between me and them") + assert.True(t, myVpnIpNet[1].Addr().Is6()) + assert.True(t, theirVpnIpNet[1].Addr().Is6()) + assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r) + + lhControl.Stop() + myControl.Stop() + theirControl.Stop() +} diff --git a/e2e/helpers.go b/e2e/helpers.go deleted file mode 100644 index 71df805..0000000 --- a/e2e/helpers.go +++ /dev/null @@ -1,125 +0,0 @@ -package e2e - -import ( - "crypto/rand" - "io" - "net" - "net/netip" - "time" - - "github.com/slackhq/nebula/cert" - "golang.org/x/crypto/curve25519" - "golang.org/x/crypto/ed25519" -) - -// NewTestCaCert will generate a CA cert -func NewTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { - pub, priv, err := ed25519.GenerateKey(rand.Reader) - if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) - } - if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) - } - - nc := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "test ca", - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: true, - InvertedGroups: make(map[string]struct{}), - }, - } - - if len(ips) > 0 { - nc.Details.Ips = make([]*net.IPNet, len(ips)) - for i, ip := range ips { - nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())} - } - } - - if len(subnets) > 0 { - nc.Details.Subnets = make([]*net.IPNet, len(subnets)) - for i, ip := range subnets { - nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())} - } - } - - if len(groups) > 0 { - nc.Details.Groups = groups - } - - err = nc.Sign(cert.Curve_CURVE25519, priv) - if err != nil { - panic(err) - } - - pem, err := nc.MarshalToPEM() - if err != nil { - panic(err) - } - - return nc, pub, priv, pem -} - -// NewTestCert will generate a signed certificate with the provided details. -// Expiry times are defaulted if you do not pass them in -func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip netip.Prefix, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { - issuer, err := ca.Sha256Sum() - if err != nil { - panic(err) - } - - if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) - } - - if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) - } - - pub, rawPriv := x25519Keypair() - ipb := ip.Addr().AsSlice() - nc := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: name, - Ips: []*net.IPNet{{IP: ipb[:], Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}}, - //Subnets: subnets, - Groups: groups, - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: false, - Issuer: issuer, - InvertedGroups: make(map[string]struct{}), - }, - } - - err = nc.Sign(ca.Details.Curve, key) - if err != nil { - panic(err) - } - - pem, err := nc.MarshalToPEM() - if err != nil { - panic(err) - } - - return nc, pub, cert.MarshalX25519PrivateKey(rawPriv), pem -} - -func x25519Keypair() ([]byte, []byte) { - privkey := make([]byte, 32) - if _, err := io.ReadFull(rand.Reader, privkey); err != nil { - panic(err) - } - - pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint) - if err != nil { - panic(err) - } - - return pubkey, privkey -} diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index 527f55b..e1b7ac2 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -8,6 +8,7 @@ import ( "io" "net/netip" "os" + "strings" "testing" "time" @@ -17,6 +18,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/e2e/router" "github.com/stretchr/testify/assert" @@ -26,27 +28,37 @@ import ( type m map[string]interface{} // newSimpleServer creates a nebula instance with many assumptions -func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) { +func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { l := NewTestLogger() - vpnIpNet, err := netip.ParsePrefix(sVpnIpNet) - if err != nil { - panic(err) + var vpnNetworks []netip.Prefix + for _, sn := range strings.Split(sVpnNetworks, ",") { + vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn)) + if err != nil { + panic(err) + } + vpnNetworks = append(vpnNetworks, vpnIpNet) + } + + if len(vpnNetworks) == 0 { + panic("no vpn networks") } var udpAddr netip.AddrPort - if vpnIpNet.Addr().Is4() { - budpIp := vpnIpNet.Addr().As4() + if vpnNetworks[0].Addr().Is4() { + budpIp := vpnNetworks[0].Addr().As4() budpIp[1] -= 128 udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242) } else { - budpIp := vpnIpNet.Addr().As16() - budpIp[13] -= 128 + budpIp := vpnNetworks[0].Addr().As16() + // beef for funsies + budpIp[2] = 190 + budpIp[3] = 239 udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242) } - _, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) + _, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{}) - caB, err := caCrt.MarshalToPEM() + caB, err := caCrt.MarshalPEM() if err != nil { panic(err) } @@ -88,11 +100,16 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, s } if overrides != nil { - err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice) + final := m{} + err = mergo.Merge(&final, overrides, mergo.WithAppendSlice) if err != nil { panic(err) } - mc = overrides + err = mergo.Merge(&final, mc, mergo.WithAppendSlice) + if err != nil { + panic(err) + } + mc = final } cb, err := yaml.Marshal(mc) @@ -109,7 +126,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, s panic(err) } - return control, vpnIpNet, udpAddr, c + return control, vpnNetworks, udpAddr, c } type doneCb func() @@ -132,27 +149,28 @@ func deadline(t *testing.T, seconds time.Duration) doneCb { func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) { // Send a packet from them to me - controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B")) + controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B")) bPacket := r.RouteForAllUntilTxTun(controlA) assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80) // And once more from me to them - controlA.InjectTunUDPPacket(vpnIpB, 80, 90, []byte("Hello from A")) + controlA.InjectTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A")) aPacket := r.RouteForAllUntilTxTun(controlB) assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80) } -func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control) { +func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) { // Get both host infos - hBinA := controlA.GetHostInfoByVpnIp(vpnIpB, false) - assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA") + //TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things + hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false) + assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA") - hAinB := controlB.GetHostInfoByVpnIp(vpnIpA, false) - assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB") + hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false) + assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB") // Check that both vpn and real addr are correct - assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A") - assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B") + assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A") + assert.EqualValues(t, getAddrs(vpnNetsA), hAinB.VpnAddrs, "Host A VpnIp is wrong in control B") assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A") assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B") @@ -160,25 +178,36 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIp // Check that our indexes match assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index") assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index") - - //TODO: Would be nice to assert this memory - //checkIndexes := func(name string, hm *HostMap, hi *HostInfo) { - // hBbyIndex := hmA.Indexes[hBinA.localIndexId] - // assert.NotNil(t, hBbyIndex, "Could not host info by local index in %s", name) - // assert.Equal(t, &hBbyIndex, &hBinA, "%s Indexes map did not point to the right host info", name) - // - // //TODO: remote indexes are susceptible to collision - // hBbyRemoteIndex := hmA.RemoteIndexes[hBinA.remoteIndexId] - // assert.NotNil(t, hBbyIndex, "Could not host info by remote index in %s", name) - // assert.Equal(t, &hBbyRemoteIndex, &hBinA, "%s RemoteIndexes did not point to the right host info", name) - //} - // - //// Check hostmap indexes too - //checkIndexes("hmA", hmA, hBinA) - //checkIndexes("hmB", hmB, hAinB) } func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { + if toIp.Is6() { + assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort) + } else { + assertUdpPacket4(t, expected, b, fromIp, toIp, fromPort, toPort) + } +} + +func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { + packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy) + v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6) + assert.NotNil(t, v6, "No ipv6 data found") + + assert.Equal(t, fromIp.AsSlice(), []byte(v6.SrcIP), "Source ip was incorrect") + assert.Equal(t, toIp.AsSlice(), []byte(v6.DstIP), "Dest ip was incorrect") + + udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) + assert.NotNil(t, udp, "No udp data found") + + assert.Equal(t, fromPort, uint16(udp.SrcPort), "Source port was incorrect") + assert.Equal(t, toPort, uint16(udp.DstPort), "Dest port was incorrect") + + data := packet.ApplicationLayer() + assert.NotNil(t, data) + assert.Equal(t, expected, data.Payload(), "Data was incorrect") +} + +func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy) v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) assert.NotNil(t, v4, "No ipv4 data found") @@ -197,6 +226,14 @@ func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, assert.Equal(t, expected, data.Payload(), "Data was incorrect") } +func getAddrs(ns []netip.Prefix) []netip.Addr { + var a []netip.Addr + for _, n := range ns { + a = append(a, n.Addr()) + } + return a +} + func NewTestLogger() *logrus.Logger { l := logrus.New() diff --git a/e2e/router/hostmap.go b/e2e/router/hostmap.go index c14ab2e..f2805d0 100644 --- a/e2e/router/hostmap.go +++ b/e2e/router/hostmap.go @@ -58,8 +58,9 @@ func renderHostmap(c *nebula.Control) (string, []*edge) { var lines []string var globalLines []*edge - clusterName := strings.Trim(c.GetCert().Details.Name, " ") - clusterVpnIp := c.GetCert().Details.Ips[0].IP + crt := c.GetCertState().GetDefaultCertificate() + clusterName := strings.Trim(crt.Name(), " ") + clusterVpnIp := crt.Networks()[0].Addr() r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp) hm := c.GetHostmap() @@ -101,8 +102,8 @@ func renderHostmap(c *nebula.Control) (string, []*edge) { for _, idx := range indexes { hi, ok := hm.Indexes[idx] if ok { - r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp()) - remoteClusterName := strings.Trim(hi.GetCert().Details.Name, " ") + r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnAddrs()) + remoteClusterName := strings.Trim(hi.GetCert().Certificate.Name(), " ") globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())}) _ = hi } diff --git a/e2e/router/router.go b/e2e/router/router.go index 0890570..5e52ed7 100644 --- a/e2e/router/router.go +++ b/e2e/router/router.go @@ -10,8 +10,8 @@ import ( "os" "path/filepath" "reflect" + "regexp" "sort" - "strings" "sync" "testing" "time" @@ -136,7 +136,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { panic("Duplicate listen address: " + addr.String()) } - r.vpnControls[c.GetVpnIp()] = c + for _, vpnAddr := range c.GetVpnAddrs() { + r.vpnControls[vpnAddr] = c + } + r.controls[addr] = c } @@ -213,11 +216,11 @@ func (r *R) renderFlow() { continue } participants[addr] = struct{}{} - sanAddr := strings.Replace(addr.String(), ":", "-", 1) + sanAddr := normalizeName(addr.String()) participantsVals = append(participantsVals, sanAddr) fmt.Fprintf( f, " participant %s as Nebula: %s
UDP: %s\n", - sanAddr, e.packet.from.GetVpnIp(), sanAddr, + sanAddr, e.packet.from.GetVpnAddrs(), sanAddr, ) } @@ -250,9 +253,9 @@ func (r *R) renderFlow() { fmt.Fprintf(f, " %s%s%s: %s(%s), index %v, counter: %v\n", - strings.Replace(p.from.GetUDPAddr().String(), ":", "-", 1), + normalizeName(p.from.GetUDPAddr().String()), line, - strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1), + normalizeName(p.to.GetUDPAddr().String()), h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter, ) } @@ -267,6 +270,11 @@ func (r *R) renderFlow() { } } +func normalizeName(s string) string { + rx := regexp.MustCompile("[\\[\\]\\:]") + return rx.ReplaceAllLiteralString(s, "_") +} + // IgnoreFlow tells the router to stop recording future flows that matches the provided criteria. // messageType and subType will target nebula underlay packets while tun will target nebula overlay packets // NOTE: This is a very broad system, if you set tun to true then no more tun traffic will be rendered @@ -303,7 +311,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) { func (r *R) renderHostmaps(title string) { c := maps.Values(r.controls) sort.SliceStable(c, func(i, j int) bool { - return c[i].GetVpnIp().Compare(c[j].GetVpnIp()) > 0 + return c[i].GetVpnAddrs()[0].Compare(c[j].GetVpnAddrs()[0]) > 0 }) s := renderHostmaps(c...) @@ -419,10 +427,11 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) [] // Nope, lets push the sender along case p := <-udpTx: r.Lock() - c := r.getControl(sender.GetUDPAddr(), p.To, p) + a := sender.GetUDPAddr() + c := r.getControl(a, p.To, p) if c == nil { r.Unlock() - panic("No control for udp tx") + panic("No control for udp tx " + a.String()) } fp := r.unlockedInjectFlow(sender, c, p, false) c.InjectUDPPacket(p) @@ -475,10 +484,11 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte { } else { // we are a udp tx, route and continue p := rx.Interface().(*udp.Packet) - c := r.getControl(cm[x].GetUDPAddr(), p.To, p) + a := cm[x].GetUDPAddr() + c := r.getControl(a, p.To, p) if c == nil { r.Unlock() - panic("No control for udp tx") + panic(fmt.Sprintf("No control for udp tx %s", p.To)) } fp := r.unlockedInjectFlow(cm[x], c, p, false) c.InjectUDPPacket(p) @@ -711,30 +721,42 @@ func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.C } func (r *R) formatUdpPacket(p *packet) string { - packet := gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy) - v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) - if v4 == nil { - panic("not an ipv4 packet") + var packet gopacket.Packet + var srcAddr netip.Addr + + packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv6, gopacket.Lazy) + if packet.ErrorLayer() == nil { + v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6) + if v6 == nil { + panic("not an ipv6 packet") + } + srcAddr, _ = netip.AddrFromSlice(v6.SrcIP) + } else { + packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy) + v6 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + if v6 == nil { + panic("not an ipv6 packet") + } + srcAddr, _ = netip.AddrFromSlice(v6.SrcIP) } from := "unknown" - srcAddr, _ := netip.AddrFromSlice(v4.SrcIP) if c, ok := r.vpnControls[srcAddr]; ok { from = c.GetUDPAddr().String() } - udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) - if udp == nil { + udpLayer := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) + if udpLayer == nil { panic("not a udp packet") } data := packet.ApplicationLayer() return fmt.Sprintf( " %s-->>%s: src port: %v
dest port: %v
data: \"%v\"\n", - strings.Replace(from, ":", "-", 1), - strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1), - udp.SrcPort, - udp.DstPort, + normalizeName(from), + normalizeName(p.to.GetUDPAddr().String()), + udpLayer.SrcPort, + udpLayer.DstPort, string(data.Payload()), ) } diff --git a/examples/config.yml b/examples/config.yml index 6354afa..2939090 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -13,6 +13,12 @@ pki: # disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid. #disconnect_invalid: true + # default_version controls which certificate version is used in handshakes. + # This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`. + # Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`. + # After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed. + # default_version: 1 + # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network). # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel. # The syntax is: @@ -285,7 +291,6 @@ tun: # send multiport handshakes. #tx_handshake_delay: 2 -# TODO # Configure logging level logging: # panic, fatal, error, warning, info, or debug. Default is info and is reloadable. @@ -377,10 +382,12 @@ firewall: # host: `any` or a literal hostname, ie `test-host` # group: `any` or a literal group name, ie `default-group` # groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass - # cidr: a remote CIDR, `0.0.0.0/0` is any. - # local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes. - # Default is `any` unless the certificate contains subnets and then the default is the ip issued in the certificate - # if `default_local_cidr_any` is false, otherwise its `any`. + # cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. + # local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This could be used to filter destinations when using unsafe_routes. + # If no unsafe networks are present in the certificate(s) or `default_local_cidr_any` is true then the default is any ipv4 or ipv6 network. + # Otherwise the default is any vpn network assigned to via the certificate. + # `default_local_cidr_any` defaults to false and is deprecated, it will be removed in a future release. + # If there are unsafe routes present its best to set `local_cidr` to whatever best fits the situation. # ca_name: An issuing CA name # ca_sha: An issuing CA shasum diff --git a/firewall.go b/firewall.go index 8a409d2..e9f454d 100644 --- a/firewall.go +++ b/firewall.go @@ -22,7 +22,7 @@ import ( ) type FirewallInterface interface { - AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error + AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error } type conn struct { @@ -51,10 +51,13 @@ type Firewall struct { UDPTimeout time.Duration //linux: 180s max DefaultTimeout time.Duration //linux: 600s - // Used to ensure we don't emit local packets for ips we don't own - localIps *bart.Table[struct{}] - assignedCIDR netip.Prefix - hasSubnets bool + // routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate. + // The vpn addresses are a full bit match while the unsafe networks only match the prefix + routableNetworks *bart.Table[struct{}] + + // assignedNetworks is a list of vpn networks assigned to us in the certificate. + assignedNetworks []netip.Prefix + hasUnsafeNetworks bool rules string rulesVersion uint16 @@ -67,9 +70,9 @@ type Firewall struct { } type firewallMetrics struct { - droppedLocalIP metrics.Counter - droppedRemoteIP metrics.Counter - droppedNoRule metrics.Counter + droppedLocalAddr metrics.Counter + droppedRemoteAddr metrics.Counter + droppedNoRule metrics.Counter } type FirewallConntrack struct { @@ -126,88 +129,87 @@ type firewallLocalCIDR struct { } // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. -func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall { +// The certificate provided should be the highest version loaded in memory. +func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall { //TODO: error on 0 duration - var min, max time.Duration + var tmin, tmax time.Duration if tcpTimeout < UDPTimeout { - min = tcpTimeout - max = UDPTimeout + tmin = tcpTimeout + tmax = UDPTimeout } else { - min = UDPTimeout - max = tcpTimeout + tmin = UDPTimeout + tmax = tcpTimeout } - if defaultTimeout < min { - min = defaultTimeout - } else if defaultTimeout > max { - max = defaultTimeout + if defaultTimeout < tmin { + tmin = defaultTimeout + } else if defaultTimeout > tmax { + tmax = defaultTimeout } - localIps := new(bart.Table[struct{}]) - var assignedCIDR netip.Prefix - var assignedSet bool - for _, ip := range c.Details.Ips { - //TODO: IPV6-WORK the unmap is a bit unfortunate - nip, _ := netip.AddrFromSlice(ip.IP) - nip = nip.Unmap() - nprefix := netip.PrefixFrom(nip, nip.BitLen()) - localIps.Insert(nprefix, struct{}{}) - - if !assignedSet { - // Only grabbing the first one in the cert since any more than that currently has undefined behavior - assignedCIDR = nprefix - assignedSet = true - } + routableNetworks := new(bart.Table[struct{}]) + var assignedNetworks []netip.Prefix + for _, network := range c.Networks() { + nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) + routableNetworks.Insert(nprefix, struct{}{}) + assignedNetworks = append(assignedNetworks, network) } - for _, n := range c.Details.Subnets { - nip, _ := netip.AddrFromSlice(n.IP) - ones, _ := n.Mask.Size() - nip = nip.Unmap() - localIps.Insert(netip.PrefixFrom(nip, ones), struct{}{}) + hasUnsafeNetworks := false + for _, n := range c.UnsafeNetworks() { + routableNetworks.Insert(n, struct{}{}) + hasUnsafeNetworks = true } return &Firewall{ Conntrack: &FirewallConntrack{ Conns: make(map[firewall.Packet]*conn), - TimerWheel: NewTimerWheel[firewall.Packet](min, max), + TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax), }, - InRules: newFirewallTable(), - OutRules: newFirewallTable(), - TCPTimeout: tcpTimeout, - UDPTimeout: UDPTimeout, - DefaultTimeout: defaultTimeout, - localIps: localIps, - assignedCIDR: assignedCIDR, - hasSubnets: len(c.Details.Subnets) > 0, - l: l, + InRules: newFirewallTable(), + OutRules: newFirewallTable(), + TCPTimeout: tcpTimeout, + UDPTimeout: UDPTimeout, + DefaultTimeout: defaultTimeout, + routableNetworks: routableNetworks, + assignedNetworks: assignedNetworks, + hasUnsafeNetworks: hasUnsafeNetworks, + l: l, incomingMetrics: firewallMetrics{ - droppedLocalIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil), - droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_ip", nil), - droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil), + droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil), + droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_addr", nil), + droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil), }, outgoingMetrics: firewallMetrics{ - droppedLocalIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_ip", nil), - droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_ip", nil), - droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil), + droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_addr", nil), + droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_addr", nil), + droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil), }, } } -func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *config.C) (*Firewall, error) { +func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) { + certificate := cs.getCertificate(cert.Version2) + if certificate == nil { + certificate = cs.getCertificate(cert.Version1) + } + + if certificate == nil { + panic("No certificate available to reconfigure the firewall") + } + fw := NewFirewall( l, c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), - nc, + certificate, //TODO: max_connections ) - //TODO: Flip to false after v1.9 release - fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", true) + fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false) inboundAction := c.GetString("firewall.inbound_action", "drop") switch inboundAction { @@ -287,7 +289,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort fp = ft.TCP case firewall.ProtoUDP: fp = ft.UDP - case firewall.ProtoICMP: + case firewall.ProtoICMP, firewall.ProtoICMPv6: fp = ft.ICMP case firewall.ProtoAny: fp = ft.AnyProto @@ -421,33 +423,31 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table") // Drop returns an error if the packet should be dropped, explaining why. It // returns nil if the packet should not be dropped. -func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error { +func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error { // Check if we spoke to this tuple, if we did then allow this packet if f.inConns(fp, h, caPool, localCache) { return nil } // Make sure remote address matches nebula certificate - if remoteCidr := h.remoteCidr; remoteCidr != nil { - //TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different - _, ok := remoteCidr.Lookup(fp.RemoteIP) + if h.networks != nil { + _, ok := h.networks.Lookup(fp.RemoteAddr) if !ok { - f.metrics(incoming).droppedRemoteIP.Inc(1) + f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrInvalidRemoteIP } } else { - // Simple case: Certificate has one IP and no subnets - if fp.RemoteIP != h.vpnIp { - f.metrics(incoming).droppedRemoteIP.Inc(1) + // Simple case: Certificate has one address and no unsafe networks + if h.vpnAddrs[0] != fp.RemoteAddr { + f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrInvalidRemoteIP } } // Make sure we are supposed to be handling this local ip address - //TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different - _, ok := f.localIps.Lookup(fp.LocalIP) + _, ok := f.routableNetworks.Lookup(fp.LocalAddr) if !ok { - f.metrics(incoming).droppedLocalIP.Inc(1) + f.metrics(incoming).droppedLocalAddr.Inc(1) return ErrInvalidLocalIP } @@ -492,7 +492,7 @@ func (f *Firewall) EmitStats() { metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV())) } -func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool { +func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool { if localCache != nil { if _, ok := localCache[fp]; ok { return true @@ -619,7 +619,7 @@ func (f *Firewall) evict(p firewall.Packet) { delete(conntrack.Conns, p) } -func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { +func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedCertificate, caPool *cert.CAPool) bool { if ft.AnyProto.match(p, incoming, c, caPool) { return true } @@ -633,7 +633,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC if ft.UDP.match(p, incoming, c, caPool) { return true } - case firewall.ProtoICMP: + case firewall.ProtoICMP, firewall.ProtoICMPv6: if ft.ICMP.match(p, incoming, c, caPool) { return true } @@ -663,7 +663,7 @@ func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, grou return nil } -func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { +func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCertificate, caPool *cert.CAPool) bool { // We don't have any allowed ports, bail if fp == nil { return false @@ -726,7 +726,7 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, loc return nil } -func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { +func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool *cert.CAPool) bool { if fc == nil { return false } @@ -735,18 +735,18 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool return true } - if t, ok := fc.CAShas[c.Details.Issuer]; ok { + if t, ok := fc.CAShas[c.Certificate.Issuer()]; ok { if t.match(p, c) { return true } } - s, err := caPool.GetCAForCert(c) + s, err := caPool.GetCAForCert(c.Certificate) if err != nil { return false } - return fc.CANames[s.Details.Name].match(p, c) + return fc.CANames[s.Certificate.Name()].match(p, c) } func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error { @@ -826,7 +826,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) boo return false } -func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool { +func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool { if fr == nil { return false } @@ -841,7 +841,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool found := false for _, g := range sg.Groups { - if _, ok := c.Details.InvertedGroups[g]; !ok { + if _, ok := c.InvertedGroups[g]; !ok { found = false break } @@ -855,42 +855,44 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool } if fr.Hosts != nil { - if flc, ok := fr.Hosts[c.Details.Name]; ok { + if flc, ok := fr.Hosts[c.Certificate.Name()]; ok { if flc.match(p, c) { return true } } } - matched := false - prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen()) - fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool { - if prefix.Contains(p.RemoteIP) && val.match(p, c) { - matched = true - return false + for _, v := range fr.CIDR.Supernets(netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())) { + if v.match(p, c) { + return true } - return true - }) - return matched + } + + return false } func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error { if !localIp.IsValid() { - if !f.hasSubnets || f.defaultLocalCIDRAny { + if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny { flc.Any = true return nil } - localIp = f.assignedCIDR + for _, network := range f.assignedNetworks { + flc.LocalCIDR.Insert(network, struct{}{}) + } + return nil + } else if localIp.Bits() == 0 { flc.Any = true + return nil } flc.LocalCIDR.Insert(localIp, struct{}{}) return nil } -func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool { +func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate) bool { if flc == nil { return false } @@ -899,7 +901,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate return true } - _, ok := flc.LocalCIDR.Lookup(p.LocalIP) + _, ok := flc.LocalCIDR.Lookup(p.LocalAddr) return ok } diff --git a/firewall/packet.go b/firewall/packet.go index 0cd2067..cd9c712 100644 --- a/firewall/packet.go +++ b/firewall/packet.go @@ -10,18 +10,19 @@ import ( type m map[string]interface{} const ( - ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever - ProtoTCP = 6 - ProtoUDP = 17 - ProtoICMP = 1 + ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever + ProtoTCP = 6 + ProtoUDP = 17 + ProtoICMP = 1 + ProtoICMPv6 = 58 PortAny = 0 // Special value for matching `port: any` PortFragment = -1 // Special value for matching `port: fragment` ) type Packet struct { - LocalIP netip.Addr - RemoteIP netip.Addr + LocalAddr netip.Addr + RemoteAddr netip.Addr LocalPort uint16 RemotePort uint16 Protocol uint8 @@ -30,8 +31,8 @@ type Packet struct { func (fp *Packet) Copy() *Packet { return &Packet{ - LocalIP: fp.LocalIP, - RemoteIP: fp.RemoteIP, + LocalAddr: fp.LocalAddr, + RemoteAddr: fp.RemoteAddr, LocalPort: fp.LocalPort, RemotePort: fp.RemotePort, Protocol: fp.Protocol, @@ -52,8 +53,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) { proto = fmt.Sprintf("unknown %v", fp.Protocol) } return json.Marshal(m{ - "LocalIP": fp.LocalIP.String(), - "RemoteIP": fp.RemoteIP.String(), + "LocalAddr": fp.LocalAddr.String(), + "RemoteAddr": fp.RemoteAddr.String(), "LocalPort": fp.LocalPort, "RemotePort": fp.RemotePort, "Protocol": proto, diff --git a/firewall_test.go b/firewall_test.go index 4d47e78..8d32369 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -4,7 +4,6 @@ import ( "bytes" "errors" "math" - "net" "net/netip" "testing" "time" @@ -14,11 +13,12 @@ import ( "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewFirewall(t *testing.T) { l := test.NewLogger() - c := &cert.NebulaCertificate{} + c := &dummyCert{} fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) conntrack := fw.Conntrack assert.NotNil(t, conntrack) @@ -60,7 +60,7 @@ func TestFirewall_AddRule(t *testing.T) { ob := &bytes.Buffer{} l.SetOutput(ob) - c := &cert.NebulaCertificate{} + c := &dummyCert{} fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.NotNil(t, fw.InRules) assert.NotNil(t, fw.OutRules) @@ -129,35 +129,30 @@ func TestFirewall_Drop(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: netip.MustParseAddr("1.2.3.4"), - RemoteIP: netip.MustParseAddr("1.2.3.4"), + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, Fragment: false, } - ipNet := net.IPNet{ - IP: net.IPv4(1, 2, 3, 4), - Mask: net.IPMask{255, 255, 255, 0}, - } - - c := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host1", - Ips: []*net.IPNet{&ipNet}, - Groups: []string{"default-group"}, - InvertedGroups: map[string]struct{}{"default-group": {}}, - Issuer: "signer-shasum", - }, + c := dummyCert{ + name: "host1", + networks: []netip.Prefix{netip.MustParsePrefix("1.2.3.4/24")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", } h := HostInfo{ ConnectionState: &ConnectionState{ - peerCert: &c, + peerCert: &cert.CachedCertificate{ + Certificate: &c, + InvertedGroups: map[string]struct{}{"default-group": {}}, + }, }, - vpnIp: netip.MustParseAddr("1.2.3.4"), + vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")}, } - h.CreateRemoteCIDR(&c) + h.buildNetworks(c.networks, c.unsafeNetworks) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -172,10 +167,10 @@ func TestFirewall_Drop(t *testing.T) { assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) // test remote mismatch - oldRemote := p.RemoteIP - p.RemoteIP = netip.MustParseAddr("1.2.3.10") + oldRemote := p.RemoteAddr + p.RemoteAddr = netip.MustParseAddr("1.2.3.10") assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) - p.RemoteIP = oldRemote + p.RemoteAddr = oldRemote // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) @@ -190,14 +185,14 @@ func TestFirewall_Drop(t *testing.T) { assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks - cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} + cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match - cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} + cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) @@ -217,7 +212,9 @@ func BenchmarkFirewallTable_match(b *testing.B) { b.Run("fail on proto", func(b *testing.B) { // This benchmark is showing us the cost of failing to match the protocol - c := &cert.NebulaCertificate{} + c := &cert.CachedCertificate{ + Certificate: &dummyCert{}, + } for n := 0; n < b.N; n++ { assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp)) } @@ -225,28 +222,31 @@ func BenchmarkFirewallTable_match(b *testing.B) { b.Run("pass proto, fail on port", func(b *testing.B) { // This benchmark is showing us the cost of matching a specific protocol but failing to match the port - c := &cert.NebulaCertificate{} + c := &cert.CachedCertificate{ + Certificate: &dummyCert{}, + } for n := 0; n < b.N; n++ { assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp)) } }) b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) { - c := &cert.NebulaCertificate{} + c := &cert.CachedCertificate{ + Certificate: &dummyCert{}, + } ip := netip.MustParsePrefix("9.254.254.254/32") for n := 0; n < b.N; n++ { - assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp)) + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp)) } }) b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) { - _, ip, _ := net.ParseCIDR("9.254.254.254/32") - c := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - InvertedGroups: map[string]struct{}{"nope": {}}, - Name: "nope", - Ips: []*net.IPNet{ip}, + c := &cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "nope", + networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")}, }, + InvertedGroups: map[string]struct{}{"nope": {}}, } for n := 0; n < b.N; n++ { assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)) @@ -254,25 +254,24 @@ func BenchmarkFirewallTable_match(b *testing.B) { }) b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) { - _, ip, _ := net.ParseCIDR("9.254.254.254/32") - c := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - InvertedGroups: map[string]struct{}{"nope": {}}, - Name: "nope", - Ips: []*net.IPNet{ip}, + c := &cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "nope", + networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")}, }, + InvertedGroups: map[string]struct{}{"nope": {}}, } for n := 0; n < b.N; n++ { - assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp)) + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp)) } }) b.Run("pass on group on any local cidr", func(b *testing.B) { - c := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - InvertedGroups: map[string]struct{}{"good-group": {}}, - Name: "nope", + c := &cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "nope", }, + InvertedGroups: map[string]struct{}{"good-group": {}}, } for n := 0; n < b.N; n++ { assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)) @@ -280,82 +279,28 @@ func BenchmarkFirewallTable_match(b *testing.B) { }) b.Run("pass on group on specific local cidr", func(b *testing.B) { - c := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - InvertedGroups: map[string]struct{}{"good-group": {}}, - Name: "nope", + c := &cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "nope", }, + InvertedGroups: map[string]struct{}{"good-group": {}}, } for n := 0; n < b.N; n++ { - assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp)) + assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp)) } }) b.Run("pass on name", func(b *testing.B) { - c := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - InvertedGroups: map[string]struct{}{"nope": {}}, - Name: "good-host", + c := &cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "good-host", }, + InvertedGroups: map[string]struct{}{"nope": {}}, } for n := 0; n < b.N; n++ { ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp) } }) - // - //b.Run("pass on ip", func(b *testing.B) { - // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) - // c := &cert.NebulaCertificate{ - // Details: cert.NebulaCertificateDetails{ - // InvertedGroups: map[string]struct{}{"nope": {}}, - // Name: "good-host", - // }, - // } - // for n := 0; n < b.N; n++ { - // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp) - // } - //}) - // - //b.Run("pass on local ip", func(b *testing.B) { - // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) - // c := &cert.NebulaCertificate{ - // Details: cert.NebulaCertificateDetails{ - // InvertedGroups: map[string]struct{}{"nope": {}}, - // Name: "good-host", - // }, - // } - // for n := 0; n < b.N; n++ { - // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp) - // } - //}) - // - //_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "") - // - //b.Run("pass on ip with any port", func(b *testing.B) { - // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) - // c := &cert.NebulaCertificate{ - // Details: cert.NebulaCertificateDetails{ - // InvertedGroups: map[string]struct{}{"nope": {}}, - // Name: "good-host", - // }, - // } - // for n := 0; n < b.N; n++ { - // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp) - // } - //}) - // - //b.Run("pass on local ip with any port", func(b *testing.B) { - // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) - // c := &cert.NebulaCertificate{ - // Details: cert.NebulaCertificateDetails{ - // InvertedGroups: map[string]struct{}{"nope": {}}, - // Name: "good-host", - // }, - // } - // for n := 0; n < b.N; n++ { - // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp) - // } - //}) } func TestFirewall_Drop2(t *testing.T) { @@ -364,49 +309,47 @@ func TestFirewall_Drop2(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: netip.MustParseAddr("1.2.3.4"), - RemoteIP: netip.MustParseAddr("1.2.3.4"), + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, Fragment: false, } - ipNet := net.IPNet{ - IP: net.IPv4(1, 2, 3, 4), - Mask: net.IPMask{255, 255, 255, 0}, - } + network := netip.MustParsePrefix("1.2.3.4/24") - c := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host1", - Ips: []*net.IPNet{&ipNet}, - InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}}, + c := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host1", + networks: []netip.Prefix{network}, }, + InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}}, } h := HostInfo{ ConnectionState: &ConnectionState{ peerCert: &c, }, - vpnIp: netip.MustParseAddr(ipNet.IP.String()), + vpnAddrs: []netip.Addr{network.Addr()}, } - h.CreateRemoteCIDR(&c) + h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) - c1 := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host1", - Ips: []*net.IPNet{&ipNet}, - InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}}, + c1 := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host1", + networks: []netip.Prefix{network}, }, + InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}}, } h1 := HostInfo{ + vpnAddrs: []netip.Addr{network.Addr()}, ConnectionState: &ConnectionState{ peerCert: &c1, }, } - h1.CreateRemoteCIDR(&c1) + h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() @@ -423,72 +366,68 @@ func TestFirewall_Drop3(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: netip.MustParseAddr("1.2.3.4"), - RemoteIP: netip.MustParseAddr("1.2.3.4"), + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), LocalPort: 1, RemotePort: 1, Protocol: firewall.ProtoUDP, Fragment: false, } - ipNet := net.IPNet{ - IP: net.IPv4(1, 2, 3, 4), - Mask: net.IPMask{255, 255, 255, 0}, - } - - c := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host-owner", - Ips: []*net.IPNet{&ipNet}, + network := netip.MustParsePrefix("1.2.3.4/24") + c := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host-owner", + networks: []netip.Prefix{network}, }, } - c1 := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host1", - Ips: []*net.IPNet{&ipNet}, - Issuer: "signer-sha-bad", + c1 := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host1", + networks: []netip.Prefix{network}, + issuer: "signer-sha-bad", }, } h1 := HostInfo{ ConnectionState: &ConnectionState{ peerCert: &c1, }, - vpnIp: netip.MustParseAddr(ipNet.IP.String()), + vpnAddrs: []netip.Addr{network.Addr()}, } - h1.CreateRemoteCIDR(&c1) + h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) - c2 := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host2", - Ips: []*net.IPNet{&ipNet}, - Issuer: "signer-sha", + c2 := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host2", + networks: []netip.Prefix{network}, + issuer: "signer-sha", }, } h2 := HostInfo{ ConnectionState: &ConnectionState{ peerCert: &c2, }, - vpnIp: netip.MustParseAddr(ipNet.IP.String()), + vpnAddrs: []netip.Addr{network.Addr()}, } - h2.CreateRemoteCIDR(&c2) + h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks()) - c3 := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host3", - Ips: []*net.IPNet{&ipNet}, - Issuer: "signer-sha-bad", + c3 := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host3", + networks: []netip.Prefix{network}, + issuer: "signer-sha-bad", }, } h3 := HostInfo{ ConnectionState: &ConnectionState{ peerCert: &c3, }, - vpnIp: netip.MustParseAddr(ipNet.IP.String()), + vpnAddrs: []netip.Addr{network.Addr()}, } - h3.CreateRemoteCIDR(&c3) + h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks()) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha")) cp := cert.NewCAPool() @@ -501,6 +440,11 @@ func TestFirewall_Drop3(t *testing.T) { // c3 should fail because no match resetConntrack(fw) assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule) + + // Test a remote address match + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", "")) + assert.NoError(t, fw.Drop(p, true, &h1, cp, nil)) } func TestFirewall_DropConntrackReload(t *testing.T) { @@ -509,37 +453,33 @@ func TestFirewall_DropConntrackReload(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: netip.MustParseAddr("1.2.3.4"), - RemoteIP: netip.MustParseAddr("1.2.3.4"), + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, Fragment: false, } + network := netip.MustParsePrefix("1.2.3.4/24") - ipNet := net.IPNet{ - IP: net.IPv4(1, 2, 3, 4), - Mask: net.IPMask{255, 255, 255, 0}, - } - - c := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host1", - Ips: []*net.IPNet{&ipNet}, - Groups: []string{"default-group"}, - InvertedGroups: map[string]struct{}{"default-group": {}}, - Issuer: "signer-shasum", + c := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host1", + networks: []netip.Prefix{network}, + groups: []string{"default-group"}, + issuer: "signer-shasum", }, + InvertedGroups: map[string]struct{}{"default-group": {}}, } h := HostInfo{ ConnectionState: &ConnectionState{ peerCert: &c, }, - vpnIp: netip.MustParseAddr(ipNet.IP.String()), + vpnAddrs: []netip.Addr{network.Addr()}, } - h.CreateRemoteCIDR(&c) + h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() @@ -552,7 +492,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) oldFw := fw - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -561,7 +501,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) oldFw = fw - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -641,8 +581,6 @@ func BenchmarkLookup(b *testing.B) { ml(m, a) } }) - - //TODO: only way array lookup in array will help is if both are sorted, then maybe it's faster } func Test_parsePort(t *testing.T) { @@ -688,56 +626,59 @@ func Test_parsePort(t *testing.T) { func TestNewFirewallFromConfig(t *testing.T) { l := test.NewLogger() // Test a bad rule definition - c := &cert.NebulaCertificate{} + c := &dummyCert{} + cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil) + require.NoError(t, err) + conf := config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"} - _, err := NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") // Test both port and code conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") // Test missing host, group, cidr, ca_name and ca_sha conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") // Test code/port error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") // Test proto error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") // Test cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test local_cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test both group and groups conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") } diff --git a/go.mod b/go.mod index adb2e84..7bd4925 100644 --- a/go.mod +++ b/go.mod @@ -1,52 +1,54 @@ module github.com/slackhq/nebula -go 1.22.0 +go 1.23.6 -toolchain go1.22.2 +toolchain go1.23.7 require ( - dario.cat/mergo v1.0.0 + dario.cat/mergo v1.0.1 github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be github.com/armon/go-radix v1.0.0 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 github.com/flynn/noise v1.1.0 - github.com/gaissmai/bart v0.11.1 + github.com/gaissmai/bart v0.18.1 github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.2 - github.com/miekg/dns v1.1.61 + github.com/miekg/dns v1.1.62 + github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f - github.com/prometheus/client_golang v1.19.1 + github.com/prometheus/client_golang v1.20.4 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 github.com/sirupsen/logrus v1.9.3 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e - github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 + github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 github.com/stretchr/testify v1.9.0 - github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/crypto v0.26.0 + github.com/vishvananda/netlink v1.3.0 + golang.org/x/crypto v0.36.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 - golang.org/x/net v0.28.0 - golang.org/x/sync v0.8.0 - golang.org/x/sys v0.24.0 - golang.org/x/term v0.23.0 + golang.org/x/net v0.37.0 + golang.org/x/sync v0.12.0 + golang.org/x/sys v0.31.0 + golang.org/x/term v0.30.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.5.3 - google.golang.org/protobuf v1.34.2 + google.golang.org/protobuf v1.36.5 gopkg.in/yaml.v2 v2.4.0 gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe ) require ( github.com/beorn7/perks v1.0.1 // indirect - github.com/bits-and-blooms/bitset v1.13.0 // indirect - github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/btree v1.1.2 // indirect + github.com/klauspost/compress v1.17.9 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/prometheus/client_model v0.5.0 // indirect - github.com/prometheus/common v0.48.0 // indirect - github.com/prometheus/procfs v0.12.0 // indirect + github.com/prometheus/client_model v0.6.1 // indirect + github.com/prometheus/common v0.55.0 // indirect + github.com/prometheus/procfs v0.15.1 // indirect github.com/vishvananda/netns v0.0.4 // indirect golang.org/x/mod v0.18.0 // indirect golang.org/x/time v0.5.0 // indirect diff --git a/go.sum b/go.sum index 3afd6cb..2813b5f 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,6 @@ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= -dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= +dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= +dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= @@ -14,11 +14,9 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE= -github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= -github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 h1:M5QgkYacWj0Xs8MhpIK/5uwU02icXpEoSo9sM2aRCps= github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go.mod h1:xwIwAxMvYnVrGJPe2FKx5prTrnAjGOD8zvDOnxnrrkM= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -26,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= -github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc= -github.com/gaissmai/bart v0.11.1/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg= +github.com/gaissmai/bart v0.18.1 h1:bX2j560JC1MJpoEDevBGvXL5OZ1mkls320Vl8Igb5QQ= +github.com/gaissmai/bart v0.18.1/go.mod h1:JJzMAhNF5Rjo4SF4jWBrANuJfqY+FvsFhW7t1UZJ+XY= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= @@ -70,6 +68,8 @@ github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= +github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= @@ -80,13 +80,19 @@ github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3x github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/miekg/dns v1.1.61 h1:nLxbwF3XxhwVSm8g9Dghm9MHPaUZuqhPiGL+675ZmEs= -github.com/miekg/dns v1.1.61/go.mod h1:mnAarhS3nWaW+NVP2wTkYVIZyHNJ098SJZUki3eykwQ= +github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ= +github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ= +github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk= +github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f h1:8dM0ilqKL0Uzl42GABzzC4Oqlc3kGRILz0vgoff7nwg= @@ -100,24 +106,24 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= -github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE= -github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho= +github.com/prometheus/client_golang v1.20.4 h1:Tgh3Yr67PaOv/uTqloMsCEdeuFTatm5zIq5+qNN23vI= +github.com/prometheus/client_golang v1.20.4/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw= -github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI= +github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= +github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= -github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE= -github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc= +github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc= +github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= -github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= -github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= +github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= +github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= @@ -129,8 +135,8 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= -github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8= -github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= +github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw= +github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6/go.mod h1:39R/xuhNgVhi+K0/zst4TLrJrVmbm6LVgl4A0+ZFS5M= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= @@ -139,9 +145,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs= -github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= -github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= +github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk= +github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -151,8 +156,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= -golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= +golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= @@ -171,8 +176,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= -golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= +golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= +golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -180,30 +185,30 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= +golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= -golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU= -golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= +golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= +golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -234,8 +239,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= -google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM= +google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/handshake_ix.go b/handshake_ix.go index 0d54b01..7dffc62 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -2,10 +2,12 @@ package nebula import ( "net/netip" + "slices" "time" "github.com/flynn/noise" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" ) @@ -17,23 +19,59 @@ import ( func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { err := f.handshakeManager.allocateIndex(hh) if err != nil { - f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). + f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") return false } - certState := f.pki.GetCertState() - ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0) + // If we're connecting to a v6 address we must use a v2 cert + cs := f.pki.getCertState() + v := cs.defaultVersion + for _, a := range hh.hostinfo.vpnAddrs { + if a.Is6() { + v = cert.Version2 + break + } + } + + crt := cs.getCertificate(v) + if crt == nil { + f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs). + WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). + WithField("certVersion", v). + Error("Unable to handshake with host because no certificate is available") + return false + } + + crtHs := cs.getHandshakeBytes(v) + if crtHs == nil { + f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs). + WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). + WithField("certVersion", v). + Error("Unable to handshake with host because no certificate handshake bytes is available") + } + + ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX) + if err != nil { + f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). + WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). + WithField("certVersion", v). + Error("Failed to create connection state") + return false + } hh.hostinfo.ConnectionState = ci - hsProto := &NebulaHandshakeDetails{ - InitiatorIndex: hh.hostinfo.localIndexId, - Time: uint64(time.Now().UnixNano()), - Cert: certState.RawCertificateNoKey, + hs := &NebulaHandshake{ + Details: &NebulaHandshakeDetails{ + InitiatorIndex: hh.hostinfo.localIndexId, + Time: uint64(time.Now().UnixNano()), + Cert: crtHs, + CertVersion: uint32(v), + }, } if f.multiPort.Tx || f.multiPort.Rx { - hsProto.InitiatorMultiPort = &MultiPortDetails{ + hs.Details.InitiatorMultiPort = &MultiPortDetails{ RxSupported: f.multiPort.Rx, TxSupported: f.multiPort.Tx, BasePort: uint32(f.multiPort.TxBasePort), @@ -41,15 +79,9 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { } } - hsBytes := []byte{} - - hs := &NebulaHandshake{ - Details: hsProto, - } - hsBytes, err = hs.Marshal() - + hsBytes, err := hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). + f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).WithField("certVersion", v). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") return false } @@ -58,7 +90,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { msg, _, _, err := ci.H.WriteMessage(h, hsBytes) if err != nil { - f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). + f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") return false } @@ -73,30 +105,44 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { } func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { - certState := f.pki.GetCertState() - ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0) + cs := f.pki.getCertState() + crt := cs.GetDefaultCertificate() + if crt == nil { + f.l.WithField("udpAddr", addr). + WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). + WithField("certVersion", cs.defaultVersion). + Error("Unable to handshake with host because no certificate is available") + } + + ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX) + if err != nil { + f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Error("Failed to create connection state") + return + } + // Mark packet 1 as seen so it doesn't show up as missed ci.window.Update(f.l, 1) msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:]) if err != nil { f.l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage") + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Error("Failed to call noise.ReadMessage") return } hs := &NebulaHandshake{} err = hs.Unmarshal(msg) - /* - l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex) - */ if err != nil || hs.Details == nil { f.l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Error("Failed unmarshal handshake message") return } - remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool()) + remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool()) if err != nil { e := f.l.WithError(err).WithField("udpAddr", addr). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) @@ -109,8 +155,21 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet return } - vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP) - if !ok { + if remoteCert.Certificate.Version() != ci.myCert.Version() { + // We started off using the wrong certificate version, lets see if we can match the version that was sent to us + rc := cs.getCertificate(remoteCert.Certificate.Version()) + if rc == nil { + f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert). + Info("Unable to handshake with host due to missing certificate version") + return + } + + // Record the certificate we are actually using + ci.myCert = rc + } + + if len(remoteCert.Certificate.Networks()) == 0 { e := f.l.WithError(err).WithField("udpAddr", addr). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) @@ -122,30 +181,54 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet return } - vpnIp = vpnIp.Unmap() - certName := remoteCert.Details.Name - fingerprint, _ := remoteCert.Sha256Sum() - issuer := remoteCert.Details.Issuer + var vpnAddrs []netip.Addr + var filteredNetworks []netip.Prefix + certName := remoteCert.Certificate.Name() + fingerprint := remoteCert.Fingerprint + issuer := remoteCert.Certificate.Issuer() - if vpnIp == f.myVpnNet.Addr() { - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + for _, network := range remoteCert.Certificate.Networks() { + vpnAddr := network.Addr() + _, found := f.myVpnAddrsTable.Lookup(vpnAddr) + if found { + f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("issuer", issuer). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") + return + } + + // vpnAddrs outside our vpn networks are of no use to us, filter them out + if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok { + continue + } + + filteredNetworks = append(filteredNetworks, network) + vpnAddrs = append(vpnAddrs, vpnAddr) + } + + if len(vpnAddrs) == 0 { + f.l.WithError(err).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake") return } if addr.IsValid() { - if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.Addr()) { - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + // addr can be invalid when the tunnel is being relayed. + // We only want to apply the remote allow list for direct tunnels here + if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, addr.Addr()) { + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return } } myIndex, err := generateIndex(f.l) if err != nil { - f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -177,19 +260,19 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet ConnectionState: ci, localIndexId: myIndex, remoteIndexId: hs.Details.InitiatorIndex, - vpnIp: vpnIp, + vpnAddrs: vpnAddrs, HandshakePacket: make(map[uint8][]byte, 0), lastHandshakeTime: hs.Details.Time, multiportTx: multiportTx, multiportRx: multiportRx, relayState: RelayState{ - relays: map[netip.Addr]struct{}{}, - relayForByIp: map[netip.Addr]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, }, } - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -199,13 +282,26 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet Info("Handshake message received") hs.Details.ResponderIndex = myIndex - hs.Details.Cert = certState.RawCertificateNoKey + hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version()) + if hs.Details.Cert == nil { + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("issuer", issuer). + WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). + WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + WithField("certVersion", ci.myCert.Version()). + Error("Unable to handshake with host because no certificate handshake bytes is available") + return + } + + hs.Details.CertVersion = uint32(ci.myCert.Version()) // Update the time in case their clock is way off from ours hs.Details.Time = uint64(time.Now().UnixNano()) hsBytes, err := hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -216,14 +312,14 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2) msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes) if err != nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") return } else if dKey == nil || eKey == nil { - f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -247,9 +343,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet ci.dKey = NewNebulaCipherState(dKey) ci.eKey = NewNebulaCipherState(eKey) - hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) + hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) hostinfo.SetRemote(addr) - hostinfo.CreateRemoteCIDR(remoteCert) + hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks()) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) if err != nil { @@ -263,7 +359,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if existing.SetRemoteIfPreferred(f.hostMap, addr) { // Send a test packet to ensure the other side has also switched to // the preferred remote - f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) } msg = existing.HandshakePacket[2] @@ -278,11 +374,11 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet err = f.outside.WriteTo(msg, addr) } if err != nil { - f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithError(err).Error("Failed to send handshake message") } else { - f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). Info("Handshake message sent") } @@ -292,16 +388,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet f.l.Error("Handshake send failed: both addr and via are nil.") return } - hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp). + f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). Info("Handshake message sent") return } case ErrExistingHostInfo: // This means there was an existing tunnel and this handshake was older than the one we are currently based on - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("oldHandshakeTime", existing.lastHandshakeTime). WithField("newHandshakeTime", hostinfo.lastHandshakeTime). @@ -312,23 +408,23 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet Info("Handshake too old") // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues - f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) return case ErrLocalIndexCollision: // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp). + WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs). Error("Failed to add HostInfo due to localIndex collision") return default: // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete // And we forget to update it here - f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -351,7 +447,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet err = f.outside.WriteTo(msg, addr) } if err != nil { - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -359,7 +455,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithError(err).Error("Failed to send handshake") } else { - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -372,9 +468,12 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet f.l.Error("Handshake send failed: both addr and via are nil.") return } - hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) + // I successfully received a handshake. Just in case I marked this tunnel as 'Disestablished', ensure + // it's correctly marked as working. + via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -401,8 +500,9 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha hostinfo := hh.hostinfo if addr.IsValid() { - if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.Addr()) { - f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + // The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list. + if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, addr.Addr()) { + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return false } } @@ -410,7 +510,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha ci := hostinfo.ConnectionState msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:]) if err != nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). Error("Failed to call noise.ReadMessage") @@ -419,7 +519,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha // near future return false } else if dKey == nil || eKey == nil { - f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). Error("Noise did not arrive at a key") @@ -431,7 +531,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha hs := &NebulaHandshake{} err = hs.Unmarshal(msg) if err != nil || hs.Details == nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again @@ -452,9 +552,9 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha ) } - remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool()) + remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool()) if err != nil { - e := f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + e := f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) if f.l.Level > logrus.DebugLevel { @@ -467,8 +567,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha return true } - vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP) - if !ok { + if len(remoteCert.Certificate.Networks()) == 0 { e := f.l.WithError(err).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) @@ -476,65 +575,14 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha e = e.WithField("cert", remoteCert) } - e.Info("Invalid vpn ip from host") + e.Info("Empty networks from host") return true } - vpnIp = vpnIp.Unmap() - certName := remoteCert.Details.Name - fingerprint, _ := remoteCert.Sha256Sum() - issuer := remoteCert.Details.Issuer - - // Ensure the right host responded - if vpnIp != hostinfo.vpnIp { - f.l.WithField("intendedVpnIp", hostinfo.vpnIp).WithField("haveVpnIp", vpnIp). - WithField("udpAddr", addr).WithField("certName", certName). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Info("Incorrect host responded to handshake") - - // Release our old handshake from pending, it should not continue - f.handshakeManager.DeleteHostInfo(hostinfo) - - // Create a new hostinfo/handshake for the intended vpn ip - f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHH *HandshakeHostInfo) { - //TODO: this doesnt know if its being added or is being used for caching a packet - // Block the current used address - newHH.hostinfo.remotes = hostinfo.remotes - newHH.hostinfo.remotes.BlockRemote(addr) - - // Get the correct remote list for the host we did handshake with - hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) - - f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp). - WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())). - Info("Blocked addresses for handshakes") - - // Swap the packet store to benefit the original intended recipient - newHH.packetStore = hh.packetStore - hh.packetStore = []*cachedPacket{} - - // Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down - hostinfo.vpnIp = vpnIp - f.sendCloseTunnel(hostinfo) - }) - - return true - } - - // Mark packet 2 as seen so it doesn't show up as missed - ci.window.Update(f.l, 2) - - duration := time.Since(hh.startTime).Nanoseconds() - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - WithField("durationNs", duration). - WithField("sentCachedPackets", len(hh.packetStore)). - WithField("multiportTx", hostinfo.multiportTx).WithField("multiportRx", hostinfo.multiportRx). - Info("Handshake message received") + vpnNetworks := remoteCert.Certificate.Networks() + certName := remoteCert.Certificate.Name() + fingerprint := remoteCert.Fingerprint + issuer := remoteCert.Certificate.Issuer() hostinfo.remoteIndexId = hs.Details.ResponderIndex hostinfo.lastHandshakeTime = hs.Details.Time @@ -548,13 +596,84 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha if addr.IsValid() { hostinfo.SetRemote(addr) } else { - hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) } - // Build up the radix for the firewall if we have subnets in the cert - hostinfo.CreateRemoteCIDR(remoteCert) + var vpnAddrs []netip.Addr + var filteredNetworks []netip.Prefix + for _, network := range vpnNetworks { + // vpnAddrs outside our vpn networks are of no use to us, filter them out + vpnAddr := network.Addr() + if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok { + continue + } - // Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp + filteredNetworks = append(filteredNetworks, network) + vpnAddrs = append(vpnAddrs, vpnAddr) + } + + if len(vpnAddrs) == 0 { + f.l.WithError(err).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("issuer", issuer). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake") + return true + } + + // Ensure the right host responded + if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) { + f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks). + WithField("udpAddr", addr).WithField("certName", certName). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). + Info("Incorrect host responded to handshake") + + // Release our old handshake from pending, it should not continue + f.handshakeManager.DeleteHostInfo(hostinfo) + + // Create a new hostinfo/handshake for the intended vpn ip + f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) { + // Block the current used address + newHH.hostinfo.remotes = hostinfo.remotes + newHH.hostinfo.remotes.BlockRemote(addr) + + f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()). + WithField("vpnNetworks", vpnNetworks). + WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())). + Info("Blocked addresses for handshakes") + + // Swap the packet store to benefit the original intended recipient + newHH.packetStore = hh.packetStore + hh.packetStore = []*cachedPacket{} + + // Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down + hostinfo.vpnAddrs = vpnAddrs + f.sendCloseTunnel(hostinfo) + }) + + return true + } + + // Mark packet 2 as seen so it doesn't show up as missed + ci.window.Update(f.l, 2) + + duration := time.Since(hh.startTime).Nanoseconds() + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("issuer", issuer). + WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). + WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). + WithField("durationNs", duration). + WithField("sentCachedPackets", len(hh.packetStore)). + WithField("multiportTx", hostinfo.multiportTx).WithField("multiportRx", hostinfo.multiportRx). + Info("Handshake message received") + + // Build up the radix for the firewall if we have subnets in the cert + hostinfo.vpnAddrs = vpnAddrs + hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks()) + + // Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here f.handshakeManager.Complete(hostinfo, f) f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) diff --git a/handshake_manager.go b/handshake_manager.go index ce8af3a..0e406e7 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -7,14 +7,15 @@ import ( "encoding/binary" "errors" "net/netip" + "slices" "sync" "time" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" - "golang.org/x/exp/slices" ) const ( @@ -121,18 +122,18 @@ func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *Lig } } -func (c *HandshakeManager) Run(ctx context.Context) { - clockSource := time.NewTicker(c.config.tryInterval) +func (hm *HandshakeManager) Run(ctx context.Context) { + clockSource := time.NewTicker(hm.config.tryInterval) defer clockSource.Stop() for { select { case <-ctx.Done(): return - case vpnIP := <-c.trigger: - c.handleOutbound(vpnIP, true) + case vpnIP := <-hm.trigger: + hm.handleOutbound(vpnIP, true) case now := <-clockSource.C: - c.NextOutboundHandshakeTimerTick(now) + hm.NextOutboundHandshakeTimerTick(now) } } } @@ -140,7 +141,7 @@ func (c *HandshakeManager) Run(ctx context.Context) { func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { // First remote allow list check before we know the vpnIp if addr.IsValid() { - if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.Addr()) { + if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(addr.Addr()) { hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return } @@ -162,14 +163,14 @@ func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, } } -func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) { - c.OutboundHandshakeTimer.Advance(now) +func (hm *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) { + hm.OutboundHandshakeTimer.Advance(now) for { - vpnIp, has := c.OutboundHandshakeTimer.Purge() + vpnIp, has := hm.OutboundHandshakeTimer.Purge() if !has { break } - c.handleOutbound(vpnIp, false) + hm.handleOutbound(vpnIp, false) } } @@ -211,7 +212,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered // NB ^ This comment doesn't jive. It's how the thing gets initialized. // It's the common path. Should it update every time, in case a future LH query/queries give us more info? if hostinfo.remotes == nil { - hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp) + hostinfo.remotes = hm.lightHouse.QueryCache([]netip.Addr{vpnIp}) } remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()) @@ -226,7 +227,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered hh.lastRemotes = remotes - // TODO: this will generate a load of queries for hosts with only 1 ip + // This will generate a load of queries for hosts with only 1 ip // (such as ones registered to the lighthouse with only a private IP) // So we only do it one time after attempting 5 handshakes already. if len(remotes) <= 1 && hh.counter == 5 { @@ -293,59 +294,26 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts") // Send a RelayRequest to all known Relay IP's for _, relay := range hostinfo.remotes.relays { - // Don't relay to myself, and don't relay through the host I'm trying to connect to - if relay == vpnIp || relay == hm.lightHouse.myVpnNet.Addr() { + // Don't relay to myself + if relay == vpnIp { continue } - relayHostInfo := hm.mainHostMap.QueryVpnIp(relay) + + // Don't relay through the host I'm trying to connect to + _, found := hm.f.myVpnAddrsTable.Lookup(relay) + if found { + continue + } + + relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay) if relayHostInfo == nil || !relayHostInfo.remote.IsValid() { hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") hm.f.Handshake(relay) continue } - // Check the relay HostInfo to see if we already established a relay through it - if existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp); ok { - switch existingRelay.State { - case Established: - hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay") - hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) - case Requested: - hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") - - //TODO: IPV6-WORK - myVpnIpB := hm.f.myVpnNet.Addr().As4() - theirVpnIpB := vpnIp.As4() - - // Re-send the CreateRelay request, in case the previous one was lost. - m := NebulaControl{ - Type: NebulaControl_CreateRelayRequest, - InitiatorRelayIndex: existingRelay.LocalIndex, - RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]), - RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]), - } - msg, err := m.Marshal() - if err != nil { - hostinfo.logger(hm.l). - WithError(err). - Error("Failed to marshal Control message to create relay") - } else { - // This must send over the hostinfo, not over hm.Hosts[ip] - hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.f.myVpnNet.Addr(), - "relayTo": vpnIp, - "initiatorRelayIndex": existingRelay.LocalIndex, - "relay": relay}). - Info("send CreateRelayRequest") - } - default: - hostinfo.logger(hm.l). - WithField("vpnIp", vpnIp). - WithField("state", existingRelay.State). - WithField("relay", relayHostInfo.vpnIp). - Errorf("Relay unexpected state") - } - } else { + // Check the relay HostInfo to see if we already established a relay through + existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp) + if !ok { // No relays exist or requested yet. if relayHostInfo.remote.IsValid() { idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested) @@ -353,16 +321,35 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap") } - //TODO: IPV6-WORK - myVpnIpB := hm.f.myVpnNet.Addr().As4() - theirVpnIpB := vpnIp.As4() - m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: idx, - RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]), - RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]), } + + switch relayHostInfo.GetCert().Certificate.Version() { + case cert.Version1: + if !hm.f.myVpnAddrs[0].Is4() { + hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") + continue + } + + if !vpnIp.Is4() { + hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") + continue + } + + b := hm.f.myVpnAddrs[0].As4() + m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = vpnIp.As4() + m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + case cert.Version2: + m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0]) + m.RelayToAddr = netAddrToProtoAddr(vpnIp) + default: + hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay") + continue + } + msg, err := m.Marshal() if err != nil { hostinfo.logger(hm.l). @@ -371,13 +358,80 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered } else { hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.f.myVpnNet.Addr(), + "relayFrom": hm.f.myVpnAddrs[0], "relayTo": vpnIp, "initiatorRelayIndex": idx, "relay": relay}). Info("send CreateRelayRequest") } } + continue + } + + switch existingRelay.State { + case Established: + hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay") + hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) + case Disestablished: + // Mark this relay as 'requested' + relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested) + fallthrough + case Requested: + hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") + // Re-send the CreateRelay request, in case the previous one was lost. + m := NebulaControl{ + Type: NebulaControl_CreateRelayRequest, + InitiatorRelayIndex: existingRelay.LocalIndex, + } + + switch relayHostInfo.GetCert().Certificate.Version() { + case cert.Version1: + if !hm.f.myVpnAddrs[0].Is4() { + hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") + continue + } + + if !vpnIp.Is4() { + hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") + continue + } + + b := hm.f.myVpnAddrs[0].As4() + m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = vpnIp.As4() + m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + case cert.Version2: + m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0]) + m.RelayToAddr = netAddrToProtoAddr(vpnIp) + default: + hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay") + continue + } + msg, err := m.Marshal() + if err != nil { + hostinfo.logger(hm.l). + WithError(err). + Error("Failed to marshal Control message to create relay") + } else { + // This must send over the hostinfo, not over hm.Hosts[ip] + hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) + hm.l.WithFields(logrus.Fields{ + "relayFrom": hm.f.myVpnAddrs[0], + "relayTo": vpnIp, + "initiatorRelayIndex": existingRelay.LocalIndex, + "relay": relay}). + Info("send CreateRelayRequest") + } + case PeerRequested: + // PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case. + fallthrough + default: + hostinfo.logger(hm.l). + WithField("vpnIp", vpnIp). + WithField("state", existingRelay.State). + WithField("relay", relay). + Errorf("Relay unexpected state") + } } } @@ -407,10 +461,10 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*Hands } // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip -func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo { +func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo { hm.Lock() - if hh, ok := hm.vpnIps[vpnIp]; ok { + if hh, ok := hm.vpnIps[vpnAddr]; ok { // We are already trying to handshake with this vpn ip if cacheCb != nil { cacheCb(hh) @@ -420,12 +474,12 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands } hostinfo := &HostInfo{ - vpnIp: vpnIp, + vpnAddrs: []netip.Addr{vpnAddr}, HandshakePacket: make(map[uint8][]byte, 0), relayState: RelayState{ - relays: map[netip.Addr]struct{}{}, - relayForByIp: map[netip.Addr]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, }, } @@ -433,9 +487,9 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands hostinfo: hostinfo, startTime: time.Now(), } - hm.vpnIps[vpnIp] = hh + hm.vpnIps[vpnAddr] = hh hm.metricInitiated.Inc(1) - hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval) + hm.OutboundHandshakeTimer.Add(vpnAddr, hm.config.tryInterval) if cacheCb != nil { cacheCb(hh) @@ -443,21 +497,21 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands // If this is a static host, we don't need to wait for the HostQueryReply // We can trigger the handshake right now - _, doTrigger := hm.lightHouse.GetStaticHostList()[vpnIp] + _, doTrigger := hm.lightHouse.GetStaticHostList()[vpnAddr] if !doTrigger { // Add any calculated remotes, and trigger early handshake if one found - doTrigger = hm.lightHouse.addCalculatedRemotes(vpnIp) + doTrigger = hm.lightHouse.addCalculatedRemotes(vpnAddr) } if doTrigger { select { - case hm.trigger <- vpnIp: + case hm.trigger <- vpnAddr: default: } } hm.Unlock() - hm.lightHouse.QueryServer(vpnIp) + hm.lightHouse.QueryServer(vpnAddr) return hostinfo } @@ -478,14 +532,14 @@ var ( // // ErrLocalIndexCollision if we already have an entry in the main or pending // hostmap for the hostinfo.localIndexId. -func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) { - c.mainHostMap.Lock() - defer c.mainHostMap.Unlock() - c.Lock() - defer c.Unlock() +func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) { + hm.mainHostMap.Lock() + defer hm.mainHostMap.Unlock() + hm.Lock() + defer hm.Unlock() // Check if we already have a tunnel with this vpn ip - existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp] + existingHostInfo, found := hm.mainHostMap.Hosts[hostinfo.vpnAddrs[0]] if found && existingHostInfo != nil { testHostInfo := existingHostInfo for testHostInfo != nil { @@ -502,31 +556,31 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket return existingHostInfo, ErrExistingHostInfo } - existingHostInfo.logger(c.l).Info("Taking new handshake") + existingHostInfo.logger(hm.l).Info("Taking new handshake") } - existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId] + existingIndex, found := hm.mainHostMap.Indexes[hostinfo.localIndexId] if found { // We have a collision, but for a different hostinfo return existingIndex, ErrLocalIndexCollision } - existingPendingIndex, found := c.indexes[hostinfo.localIndexId] + existingPendingIndex, found := hm.indexes[hostinfo.localIndexId] if found && existingPendingIndex.hostinfo != hostinfo { // We have a collision, but for a different hostinfo return existingPendingIndex.hostinfo, ErrLocalIndexCollision } - existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] - if found && existingRemoteIndex != nil && existingRemoteIndex.vpnIp != hostinfo.vpnIp { + existingRemoteIndex, found := hm.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] + if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. - hostinfo.logger(c.l). - WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp). + hostinfo.logger(hm.l). + WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs). Info("New host shadows existing host remoteIndex") } - c.mainHostMap.unlockedAddHostInfo(hostinfo, f) + hm.mainHostMap.unlockedAddHostInfo(hostinfo, f) return existingHostInfo, nil } @@ -544,7 +598,7 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. hostinfo.logger(hm.l). - WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp). + WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs). Info("New host shadows existing host remoteIndex") } @@ -581,31 +635,34 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error { return errors.New("failed to generate unique localIndexId") } -func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { - c.Lock() - defer c.Unlock() - c.unlockedDeleteHostInfo(hostinfo) +func (hm *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { + hm.Lock() + defer hm.Unlock() + hm.unlockedDeleteHostInfo(hostinfo) } -func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { - delete(c.vpnIps, hostinfo.vpnIp) - if len(c.vpnIps) == 0 { - c.vpnIps = map[netip.Addr]*HandshakeHostInfo{} +func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { + for _, addr := range hostinfo.vpnAddrs { + delete(hm.vpnIps, addr) } - delete(c.indexes, hostinfo.localIndexId) - if len(c.vpnIps) == 0 { - c.indexes = map[uint32]*HandshakeHostInfo{} + if len(hm.vpnIps) == 0 { + hm.vpnIps = map[netip.Addr]*HandshakeHostInfo{} } - if c.l.Level >= logrus.DebugLevel { - c.l.WithField("hostMap", m{"mapTotalSize": len(c.vpnIps), - "vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). + delete(hm.indexes, hostinfo.localIndexId) + if len(hm.indexes) == 0 { + hm.indexes = map[uint32]*HandshakeHostInfo{} + } + + if hm.l.Level >= logrus.DebugLevel { + hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps), + "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). Debug("Pending hostmap hostInfo deleted") } } -func (hm *HandshakeManager) QueryVpnIp(vpnIp netip.Addr) *HostInfo { +func (hm *HandshakeManager) QueryVpnAddr(vpnIp netip.Addr) *HostInfo { hh := hm.queryVpnIp(vpnIp) if hh != nil { return hh.hostinfo @@ -634,37 +691,37 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo { return hm.indexes[index] } -func (c *HandshakeManager) GetPreferredRanges() []netip.Prefix { - return c.mainHostMap.GetPreferredRanges() +func (hm *HandshakeManager) GetPreferredRanges() []netip.Prefix { + return hm.mainHostMap.GetPreferredRanges() } -func (c *HandshakeManager) ForEachVpnIp(f controlEach) { - c.RLock() - defer c.RUnlock() +func (hm *HandshakeManager) ForEachVpnAddr(f controlEach) { + hm.RLock() + defer hm.RUnlock() - for _, v := range c.vpnIps { + for _, v := range hm.vpnIps { f(v.hostinfo) } } -func (c *HandshakeManager) ForEachIndex(f controlEach) { - c.RLock() - defer c.RUnlock() +func (hm *HandshakeManager) ForEachIndex(f controlEach) { + hm.RLock() + defer hm.RUnlock() - for _, v := range c.indexes { + for _, v := range hm.indexes { f(v.hostinfo) } } -func (c *HandshakeManager) EmitStats() { - c.RLock() - hostLen := len(c.vpnIps) - indexLen := len(c.indexes) - c.RUnlock() +func (hm *HandshakeManager) EmitStats() { + hm.RLock() + hostLen := len(hm.vpnIps) + indexLen := len(hm.indexes) + hm.RUnlock() metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen)) metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen)) - c.mainHostMap.EmitStats() + hm.mainHostMap.EmitStats() } // Utility functions below diff --git a/handshake_manager_test.go b/handshake_manager_test.go index a78b45f..7edc55b 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -14,21 +14,20 @@ import ( func Test_NewHandshakeManagerVpnIp(t *testing.T) { l := test.NewLogger() - vpncidr := netip.MustParsePrefix("172.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24") ip := netip.MustParseAddr("172.1.1.2") preferredRanges := []netip.Prefix{localrange} - mainHM := newHostMap(l, vpncidr) + mainHM := newHostMap(l) mainHM.preferredRanges.Store(&preferredRanges) lh := newTestLighthouse() cs := &CertState{ - RawCertificate: []byte{}, - PrivateKey: []byte{}, - Certificate: &cert.NebulaCertificate{}, - RawCertificateNoKey: []byte{}, + defaultVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1}, + v1HandshakeBytes: []byte{}, } blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig) @@ -42,7 +41,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { i2 := blah.StartHandshake(ip, nil) assert.Same(t, i, i2) - i.remotes = NewRemoteList(nil) + i.remotes = NewRemoteList([]netip.Addr{}, nil) // Adding something to pending should not affect the main hostmap assert.Len(t, mainHM.Hosts, 0) @@ -80,16 +79,24 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) { type mockEncWriter struct { } -func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) { +func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) { return } -func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { +func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) { return } -func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { +func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) { return } -func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {} +func (mw *mockEncWriter) Handshake(_ netip.Addr) {} + +func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo { + return nil +} + +func (mw *mockEncWriter) GetCertState() *CertState { + return &CertState{defaultVersion: cert.Version2} +} diff --git a/hostmap.go b/hostmap.go index 40031a3..7b2f2b1 100644 --- a/hostmap.go +++ b/hostmap.go @@ -35,6 +35,7 @@ const ( Requested = iota PeerRequested Established + Disestablished ) const ( @@ -48,7 +49,7 @@ type Relay struct { State int LocalIndex uint32 RemoteIndex uint32 - PeerIp netip.Addr + PeerAddr netip.Addr } type HostMap struct { @@ -58,7 +59,6 @@ type HostMap struct { RemoteIndexes map[uint32]*HostInfo Hosts map[netip.Addr]*HostInfo preferredRanges atomic.Pointer[[]netip.Prefix] - vpnCIDR netip.Prefix l *logrus.Logger } @@ -68,9 +68,12 @@ type HostMap struct { type RelayState struct { sync.RWMutex - relays map[netip.Addr]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer - relayForByIp map[netip.Addr]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info - relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info + relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer + // For data race avoidance, the contents of a *Relay are treated immutably. To update a *Relay, copy the existing data, + // modify what needs to be updated, and store the new modified copy in the relayForByIp and relayForByIdx maps (with + // the RelayState Lock held) + relayForByAddr map[netip.Addr]*Relay // Maps vpnAddr of peers for which this HostInfo is a relay to some Relay info + relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info } func (rs *RelayState) DeleteRelay(ip netip.Addr) { @@ -79,6 +82,28 @@ func (rs *RelayState) DeleteRelay(ip netip.Addr) { delete(rs.relays, ip) } +func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) { + rs.Lock() + defer rs.Unlock() + if r, ok := rs.relayForByAddr[vpnIp]; ok { + newRelay := *r + newRelay.State = state + rs.relayForByAddr[newRelay.PeerAddr] = &newRelay + rs.relayForByIdx[newRelay.LocalIndex] = &newRelay + } +} + +func (rs *RelayState) UpdateRelayForByIdxState(idx uint32, state int) { + rs.Lock() + defer rs.Unlock() + if r, ok := rs.relayForByIdx[idx]; ok { + newRelay := *r + newRelay.State = state + rs.relayForByAddr[newRelay.PeerAddr] = &newRelay + rs.relayForByIdx[newRelay.LocalIndex] = &newRelay + } +} + func (rs *RelayState) CopyAllRelayFor() []*Relay { rs.RLock() defer rs.RUnlock() @@ -89,10 +114,10 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay { return ret } -func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) { +func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) { rs.RLock() defer rs.RUnlock() - r, ok := rs.relayForByIp[ip] + r, ok := rs.relayForByAddr[addr] return r, ok } @@ -115,8 +140,8 @@ func (rs *RelayState) CopyRelayIps() []netip.Addr { func (rs *RelayState) CopyRelayForIps() []netip.Addr { rs.RLock() defer rs.RUnlock() - currentRelays := make([]netip.Addr, 0, len(rs.relayForByIp)) - for relayIp := range rs.relayForByIp { + currentRelays := make([]netip.Addr, 0, len(rs.relayForByAddr)) + for relayIp := range rs.relayForByAddr { currentRelays = append(currentRelays, relayIp) } return currentRelays @@ -135,7 +160,7 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 { func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool { rs.Lock() defer rs.Unlock() - r, ok := rs.relayForByIp[vpnIp] + r, ok := rs.relayForByAddr[vpnIp] if !ok { return false } @@ -143,7 +168,7 @@ func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool newRelay.State = Established newRelay.RemoteIndex = remoteIdx rs.relayForByIdx[r.LocalIndex] = &newRelay - rs.relayForByIp[r.PeerIp] = &newRelay + rs.relayForByAddr[r.PeerAddr] = &newRelay return true } @@ -158,14 +183,14 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re newRelay.State = Established newRelay.RemoteIndex = remoteIdx rs.relayForByIdx[r.LocalIndex] = &newRelay - rs.relayForByIp[r.PeerIp] = &newRelay + rs.relayForByAddr[r.PeerAddr] = &newRelay return &newRelay, true } func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) { rs.RLock() defer rs.RUnlock() - r, ok := rs.relayForByIp[vpnIp] + r, ok := rs.relayForByAddr[vpnIp] return r, ok } @@ -179,7 +204,7 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) { func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) { rs.Lock() defer rs.Unlock() - rs.relayForByIp[ip] = r + rs.relayForByAddr[ip] = r rs.relayForByIdx[idx] = r } @@ -190,10 +215,16 @@ type HostInfo struct { ConnectionState *ConnectionState remoteIndexId uint32 localIndexId uint32 - vpnIp netip.Addr - recvError atomic.Uint32 - remoteCidr *bart.Table[struct{}] - relayState RelayState + + // vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks + // The host may have other vpn addresses that are outside our + // vpn networks but were removed because they are not usable + vpnAddrs []netip.Addr + recvError atomic.Uint32 + + // networks are both all vpn and unsafe networks assigned to this host + networks *bart.Table[struct{}] + relayState RelayState // If true, we should send to this remote using multiport multiportTx bool @@ -247,28 +278,26 @@ type cachedPacketMetrics struct { dropped metrics.Counter } -func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR netip.Prefix, c *config.C) *HostMap { - hm := newHostMap(l, vpnCIDR) +func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap { + hm := newHostMap(l) hm.reload(c, true) c.RegisterReloadCallback(func(c *config.C) { hm.reload(c, false) }) - l.WithField("network", hm.vpnCIDR.String()). - WithField("preferredRanges", hm.GetPreferredRanges()). + l.WithField("preferredRanges", hm.GetPreferredRanges()). Info("Main HostMap created") return hm } -func newHostMap(l *logrus.Logger, vpnCIDR netip.Prefix) *HostMap { +func newHostMap(l *logrus.Logger) *HostMap { return &HostMap{ Indexes: map[uint32]*HostInfo{}, Relays: map[uint32]*HostInfo{}, RemoteIndexes: map[uint32]*HostInfo{}, Hosts: map[netip.Addr]*HostInfo{}, - vpnCIDR: vpnCIDR, l: l, } } @@ -311,17 +340,6 @@ func (hm *HostMap) EmitStats() { metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen)) } -func (hm *HostMap) RemoveRelay(localIdx uint32) { - hm.Lock() - _, ok := hm.Relays[localIdx] - if !ok { - hm.Unlock() - return - } - delete(hm.Relays, localIdx) - hm.Unlock() -} - // DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool { // Delete the host itself, ensuring it's not modified anymore @@ -341,48 +359,73 @@ func (hm *HostMap) MakePrimary(hostinfo *HostInfo) { } func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) { - oldHostinfo := hm.Hosts[hostinfo.vpnIp] + // Get the current primary, if it exists + oldHostinfo := hm.Hosts[hostinfo.vpnAddrs[0]] + + // Every address in the hostinfo gets elevated to primary + for _, vpnAddr := range hostinfo.vpnAddrs { + //NOTE: It is possible that we leave a dangling hostinfo here but connection manager works on + // indexes so it should be fine. + hm.Hosts[vpnAddr] = hostinfo + } + + // If we are already primary then we won't bother re-linking if oldHostinfo == hostinfo { return } + // Unlink this hostinfo if hostinfo.prev != nil { hostinfo.prev.next = hostinfo.next } - if hostinfo.next != nil { hostinfo.next.prev = hostinfo.prev } - hm.Hosts[hostinfo.vpnIp] = hostinfo - + // If there wasn't a previous primary then clear out any links if oldHostinfo == nil { + hostinfo.next = nil + hostinfo.prev = nil return } + // Relink the hostinfo as primary hostinfo.next = oldHostinfo oldHostinfo.prev = hostinfo hostinfo.prev = nil } func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { - primary, ok := hm.Hosts[hostinfo.vpnIp] + for _, addr := range hostinfo.vpnAddrs { + h := hm.Hosts[addr] + for h != nil { + if h == hostinfo { + hm.unlockedInnerDeleteHostInfo(h, addr) + } + h = h.next + } + } +} + +func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Addr) { + primary, ok := hm.Hosts[addr] + isLastHostinfo := hostinfo.next == nil && hostinfo.prev == nil if ok && primary == hostinfo { - // The vpnIp pointer points to the same hostinfo as the local index id, we can remove it - delete(hm.Hosts, hostinfo.vpnIp) + // The vpn addr pointer points to the same hostinfo as the local index id, we can remove it + delete(hm.Hosts, addr) if len(hm.Hosts) == 0 { hm.Hosts = map[netip.Addr]*HostInfo{} } if hostinfo.next != nil { - // We had more than 1 hostinfo at this vpnip, promote the next in the list to primary - hm.Hosts[hostinfo.vpnIp] = hostinfo.next + // We had more than 1 hostinfo at this vpn addr, promote the next in the list to primary + hm.Hosts[addr] = hostinfo.next // It is primary, there is no previous hostinfo now hostinfo.next.prev = nil } } else { - // Relink if we were in the middle of multiple hostinfos for this vpn ip + // Relink if we were in the middle of multiple hostinfos for this vpn addr if hostinfo.prev != nil { hostinfo.prev.next = hostinfo.next } @@ -412,10 +455,16 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { if hm.l.Level >= logrus.DebugLevel { hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts), - "vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). + "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). Debug("Hostmap hostInfo deleted") } + if isLastHostinfo { + // I have lost connectivity to my peers. My relay tunnel is likely broken. Mark the next + // hops as 'Requested' so that new relay tunnels are created in the future. + hm.unlockedDisestablishVpnAddrRelayFor(hostinfo) + } + // Clean up any local relay indexes for which I am acting as a relay hop for _, localRelayIdx := range hostinfo.relayState.CopyRelayForIdxs() { delete(hm.Relays, localRelayIdx) } @@ -454,11 +503,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo { } } -func (hm *HostMap) QueryVpnIp(vpnIp netip.Addr) *HostInfo { - return hm.queryVpnIp(vpnIp, nil) +func (hm *HostMap) QueryVpnAddr(vpnIp netip.Addr) *HostInfo { + return hm.queryVpnAddr(vpnIp, nil) } -func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostInfo, *Relay, error) { +func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp netip.Addr) (*HostInfo, *Relay, error) { hm.RLock() defer hm.RUnlock() @@ -466,17 +515,42 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostIn if !ok { return nil, nil, errors.New("unable to find host") } + for h != nil { - r, ok := h.relayState.QueryRelayForByIp(targetIp) - if ok && r.State == Established { - return h, r, nil + for _, targetIp := range targetIps { + r, ok := h.relayState.QueryRelayForByIp(targetIp) + if ok && r.State == Established { + return h, r, nil + } } h = h.next } + return nil, nil, errors.New("unable to find host with relay") } -func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo { +func (hm *HostMap) unlockedDisestablishVpnAddrRelayFor(hi *HostInfo) { + for _, relayHostIp := range hi.relayState.CopyRelayIps() { + if h, ok := hm.Hosts[relayHostIp]; ok { + for h != nil { + h.relayState.UpdateRelayForByIpState(hi.vpnAddrs[0], Disestablished) + h = h.next + } + } + } + for _, rs := range hi.relayState.CopyAllRelayFor() { + if rs.Type == ForwardingType { + if h, ok := hm.Hosts[rs.PeerAddr]; ok { + for h != nil { + h.relayState.UpdateRelayForByIpState(hi.vpnAddrs[0], Disestablished) + h = h.next + } + } + } + } +} + +func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo { hm.RLock() if h, ok := hm.Hosts[vpnIp]; ok { hm.RUnlock() @@ -497,25 +571,30 @@ func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInf func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { if f.serveDns { remoteCert := hostinfo.ConnectionState.peerCert - dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String()) + dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs) } - - existing := hm.Hosts[hostinfo.vpnIp] - hm.Hosts[hostinfo.vpnIp] = hostinfo - - if existing != nil { - hostinfo.next = existing - existing.prev = hostinfo + for _, addr := range hostinfo.vpnAddrs { + hm.unlockedInnerAddHostInfo(addr, hostinfo, f) } hm.Indexes[hostinfo.localIndexId] = hostinfo hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts), - "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}). + hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts), + "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}). Debug("Hostmap vpnIp added") } +} + +func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo, f *Interface) { + existing := hm.Hosts[vpnAddr] + hm.Hosts[vpnAddr] = hostinfo + + if existing != nil && existing != hostinfo { + hostinfo.next = existing + existing.prev = hostinfo + } i := 1 check := hostinfo @@ -533,7 +612,7 @@ func (hm *HostMap) GetPreferredRanges() []netip.Prefix { return *hm.preferredRanges.Load() } -func (hm *HostMap) ForEachVpnIp(f controlEach) { +func (hm *HostMap) ForEachVpnAddr(f controlEach) { hm.RLock() defer hm.RUnlock() @@ -587,11 +666,11 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interfac } i.nextLHQuery.Store(now + ifce.reQueryWait.Load()) - ifce.lightHouse.QueryServer(i.vpnIp) + ifce.lightHouse.QueryServer(i.vpnAddrs[0]) } } -func (i *HostInfo) GetCert() *cert.NebulaCertificate { +func (i *HostInfo) GetCert() *cert.CachedCertificate { if i.ConnectionState != nil { return i.ConnectionState.peerCert } @@ -602,7 +681,7 @@ func (i *HostInfo) SetRemote(remote netip.AddrPort) { // We copy here because we likely got this remote from a source that reuses the object if i.remote != remote { i.remote = remote - i.remotes.LearnRemote(i.vpnIp, remote) + i.remotes.LearnRemote(i.vpnAddrs[0], remote) } } @@ -653,29 +732,20 @@ func (i *HostInfo) RecvErrorExceeded() bool { return true } -func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) { - if len(c.Details.Ips) == 1 && len(c.Details.Subnets) == 0 { +func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) { + if len(networks) == 1 && len(unsafeNetworks) == 0 { // Simple case, no CIDRTree needed return } - remoteCidr := new(bart.Table[struct{}]) - for _, ip := range c.Details.Ips { - //TODO: IPV6-WORK what to do when ip is invalid? - nip, _ := netip.AddrFromSlice(ip.IP) - nip = nip.Unmap() - bits, _ := ip.Mask.Size() - remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{}) + i.networks = new(bart.Table[struct{}]) + for _, network := range networks { + i.networks.Insert(network, struct{}{}) } - for _, n := range c.Details.Subnets { - //TODO: IPV6-WORK what to do when ip is invalid? - nip, _ := netip.AddrFromSlice(n.IP) - nip = nip.Unmap() - bits, _ := n.Mask.Size() - remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{}) + for _, network := range unsafeNetworks { + i.networks.Insert(network, struct{}{}) } - i.remoteCidr = remoteCidr } func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { @@ -683,13 +753,13 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { return logrus.NewEntry(l) } - li := l.WithField("vpnIp", i.vpnIp). + li := l.WithField("vpnAddrs", i.vpnAddrs). WithField("localIndex", i.localIndexId). WithField("remoteIndex", i.remoteIndexId) if connState := i.ConnectionState; connState != nil { if peerCert := connState.peerCert; peerCert != nil { - li = li.WithField("certName", peerCert.Details.Name) + li = li.WithField("certName", peerCert.Certificate.Name()) } } @@ -698,9 +768,9 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { // Utility functions -func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { +func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { //FIXME: This function is pretty garbage - var ips []netip.Addr + var finalAddrs []netip.Addr ifaces, _ := net.Interfaces() for _, i := range ifaces { allow := allowList.AllowName(i.Name) @@ -712,39 +782,36 @@ func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { continue } addrs, _ := i.Addrs() - for _, addr := range addrs { - var ip net.IP - switch v := addr.(type) { + for _, rawAddr := range addrs { + var addr netip.Addr + switch v := rawAddr.(type) { case *net.IPNet: //continue - ip = v.IP + addr, _ = netip.AddrFromSlice(v.IP) case *net.IPAddr: - ip = v.IP + addr, _ = netip.AddrFromSlice(v.IP) } - nip, ok := netip.AddrFromSlice(ip) - if !ok { + if !addr.IsValid() { if l.Level >= logrus.DebugLevel { - l.WithField("localIp", ip).Debug("ip was invalid for netip") + l.WithField("localAddr", rawAddr).Debug("addr was invalid") } continue } - nip = nip.Unmap() + addr = addr.Unmap() - //TODO: Filtering out link local for now, this is probably the most correct thing - //TODO: Would be nice to filter out SLAAC MAC based ips as well - if nip.IsLoopback() == false && nip.IsLinkLocalUnicast() == false { - allow := allowList.Allow(nip) + if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false { + isAllowed := allowList.Allow(addr) if l.Level >= logrus.TraceLevel { - l.WithField("localIp", nip).WithField("allow", allow).Trace("localAllowList.Allow") + l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow") } - if !allow { + if !isAllowed { continue } - ips = append(ips, nip) + finalAddrs = append(finalAddrs, addr) } } } - return ips + return finalAddrs } diff --git a/hostmap_test.go b/hostmap_test.go index 7e2feb8..e974340 100644 --- a/hostmap_test.go +++ b/hostmap_test.go @@ -11,17 +11,14 @@ import ( func TestHostMap_MakePrimary(t *testing.T) { l := test.NewLogger() - hm := newHostMap( - l, - netip.MustParsePrefix("10.0.0.1/24"), - ) + hm := newHostMap(l) f := &Interface{} - h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1} - h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2} - h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3} - h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4} + h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1} + h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2} + h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3} + h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4} hm.unlockedAddHostInfo(h4, f) hm.unlockedAddHostInfo(h3, f) @@ -29,7 +26,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.unlockedAddHostInfo(h1, f) // Make sure we go h1 -> h2 -> h3 -> h4 - prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -44,7 +41,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h3) // Make sure we go h3 -> h1 -> h2 -> h4 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h3.localIndexId, prim.localIndexId) assert.Equal(t, h1.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -59,7 +56,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h4) // Make sure we go h4 -> h3 -> h1 -> h2 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -74,7 +71,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h4) // Make sure we go h4 -> h3 -> h1 -> h2 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -88,19 +85,16 @@ func TestHostMap_MakePrimary(t *testing.T) { func TestHostMap_DeleteHostInfo(t *testing.T) { l := test.NewLogger() - hm := newHostMap( - l, - netip.MustParsePrefix("10.0.0.1/24"), - ) + hm := newHostMap(l) f := &Interface{} - h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1} - h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2} - h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3} - h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4} - h5 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 5} - h6 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 6} + h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1} + h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2} + h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3} + h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4} + h5 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 5} + h6 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 6} hm.unlockedAddHostInfo(h6, f) hm.unlockedAddHostInfo(h5, f) @@ -116,7 +110,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h) // Make sure we go h1 -> h2 -> h3 -> h4 -> h5 - prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -135,7 +129,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h1.next) // Make sure we go h2 -> h3 -> h4 -> h5 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -153,7 +147,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h3.next) // Make sure we go h2 -> h4 -> h5 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -169,7 +163,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h5.next) // Make sure we go h2 -> h4 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -183,7 +177,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h2.next) // Make sure we only have h4 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Nil(t, prim.prev) assert.Nil(t, prim.next) @@ -195,7 +189,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h4.next) // Make sure we have nil - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Nil(t, prim) } @@ -203,11 +197,7 @@ func TestHostMap_reload(t *testing.T) { l := test.NewLogger() c := config.NewC(l) - hm := NewHostMapFromConfig( - l, - netip.MustParsePrefix("10.0.0.1/24"), - c, - ) + hm := NewHostMapFromConfig(l, c) toS := func(ipn []netip.Prefix) []string { var s []string diff --git a/hostmap_tester.go b/hostmap_tester.go index b2d1d1b..fe40c53 100644 --- a/hostmap_tester.go +++ b/hostmap_tester.go @@ -9,8 +9,8 @@ import ( "net/netip" ) -func (i *HostInfo) GetVpnIp() netip.Addr { - return i.vpnIp +func (i *HostInfo) GetVpnAddrs() []netip.Addr { + return i.vpnAddrs } func (i *HostInfo) GetLocalIndex() uint32 { diff --git a/inside.go b/inside.go index 467f1f2..6b790f5 100644 --- a/inside.go +++ b/inside.go @@ -21,14 +21,18 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } // Ignore local broadcast packets - if f.dropLocalBroadcast && fwPacket.RemoteIP == f.myBroadcastAddr { - return + if f.dropLocalBroadcast { + _, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr) + if found { + return + } } - if fwPacket.RemoteIP == f.myVpnNet.Addr() { + _, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr) + if found { // Immediately forward packets from self to self. // This should only happen on Darwin-based and FreeBSD hosts, which - // routes packets from the Nebula IP to the Nebula IP through the Nebula + // routes packets from the Nebula addr to the Nebula addr through the Nebula // TUN device. if immediatelyForwardToSelf { _, err := f.readers[q].Write(packet) @@ -37,25 +41,25 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } } // Otherwise, drop. On linux, we should never see these packets - Linux - // routes packets from the nebula IP to the nebula IP through the loopback device. + // routes packets from the nebula addr to the nebula addr through the loopback device. return } // Ignore multicast packets - if f.dropMulticast && fwPacket.RemoteIP.IsMulticast() { + if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() { return } - hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(hh *HandshakeHostInfo) { + hostinfo, ready := f.getOrHandshake(fwPacket.RemoteAddr, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) }) if hostinfo == nil { f.rejectInside(packet, out, q) if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnIp", fwPacket.RemoteIP). + f.l.WithField("vpnAddr", fwPacket.RemoteAddr). WithField("fwPacket", fwPacket). - Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes") + Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks") } return } @@ -118,21 +122,22 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo * f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q, nil) } -func (f *Interface) Handshake(vpnIp netip.Addr) { - f.getOrHandshake(vpnIp, nil) +func (f *Interface) Handshake(vpnAddr netip.Addr) { + f.getOrHandshake(vpnAddr, nil) } -// getOrHandshake returns nil if the vpnIp is not routable. +// getOrHandshake returns nil if the vpnAddr is not routable. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel -func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { - if !f.myVpnNet.Contains(vpnIp) { - vpnIp = f.inside.RouteFor(vpnIp) - if !vpnIp.IsValid() { +func (f *Interface) getOrHandshake(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { + _, found := f.myVpnNetworksTable.Lookup(vpnAddr) + if !found { + vpnAddr = f.inside.RouteFor(vpnAddr) + if !vpnAddr.IsValid() { return nil, false } } - return f.handshakeManager.GetOrHandshake(vpnIp, cacheCallback) + return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback) } func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { @@ -157,16 +162,16 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0, nil) } -// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp -func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) { - hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) { +// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr +func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) { + hostInfo, ready := f.getOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) }) if hostInfo == nil { if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnIp", vpnIp). - Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes") + f.l.WithField("vpnAddr", vpnAddr). + Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes") } return } @@ -259,7 +264,6 @@ func (f *Interface) SendVia(via *HostInfo, func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int, udpPortGetter udp.SendPortGetter) { if ci.eKey == nil { - //TODO: log warning return } @@ -303,14 +307,14 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType f.connectionManager.Out(hostinfo.localIndexId) // Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against - // all our IPs and enable a faster roaming. + // all our addrs and enable a faster roaming. if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount { //NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. - f.lightHouse.QueryServer(hostinfo.vpnIp) + f.lightHouse.QueryServer(hostinfo.vpnAddrs[0]) hostinfo.lastRebindCount = f.rebindCount if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter") + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter") } } @@ -354,7 +358,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType } else { // Try to send via a relay for _, relayIP := range hostinfo.relayState.CopyRelayIps() { - relayHostInfo, relay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relayIP) + relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP) if err != nil { hostinfo.relayState.DeleteRelay(relayIP) hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo") diff --git a/interface.go b/interface.go index 63abe8d..5c7048e 100644 --- a/interface.go +++ b/interface.go @@ -2,7 +2,6 @@ package nebula import ( "context" - "encoding/binary" "errors" "fmt" "io" @@ -12,6 +11,7 @@ import ( "sync/atomic" "time" + "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" @@ -28,7 +28,6 @@ type InterfaceConfig struct { Outside udp.Conn Inside overlay.Device pki *PKI - Cipher string Firewall *Firewall ServeDns bool HandshakeManager *HandshakeManager @@ -52,25 +51,27 @@ type InterfaceConfig struct { } type Interface struct { - hostMap *HostMap - outside udp.Conn - inside overlay.Device - pki *PKI - cipher string - firewall *Firewall - connectionManager *connectionManager - handshakeManager *HandshakeManager - serveDns bool - createTime time.Time - lightHouse *LightHouse - myBroadcastAddr netip.Addr - myVpnNet netip.Prefix - dropLocalBroadcast bool - dropMulticast bool - routines int - disconnectInvalid atomic.Bool - closed atomic.Bool - relayManager *relayManager + hostMap *HostMap + outside udp.Conn + inside overlay.Device + pki *PKI + firewall *Firewall + connectionManager *connectionManager + handshakeManager *HandshakeManager + serveDns bool + createTime time.Time + lightHouse *LightHouse + myBroadcastAddrsTable *bart.Table[struct{}] + myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate + myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate + myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate + myVpnNetworksTable *bart.Table[struct{}] // A table of networks assigned to us via our certificate + dropLocalBroadcast bool + dropMulticast bool + routines int + disconnectInvalid atomic.Bool + closed atomic.Bool + relayManager *relayManager tryPromoteEvery atomic.Uint32 reQueryEvery atomic.Uint32 @@ -114,9 +115,11 @@ type EncWriter interface { out []byte, nocopy bool, ) - SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) + SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) - Handshake(vpnIp netip.Addr) + Handshake(vpnAddr netip.Addr) + GetHostInfo(vpnAddr netip.Addr) *HostInfo + GetCertState() *CertState } type sendRecvErrorConfig uint8 @@ -127,10 +130,10 @@ const ( sendRecvErrorPrivate ) -func (s sendRecvErrorConfig) ShouldSendRecvError(ip netip.AddrPort) bool { +func (s sendRecvErrorConfig) ShouldSendRecvError(endpoint netip.AddrPort) bool { switch s { case sendRecvErrorPrivate: - return ip.Addr().IsPrivate() + return endpoint.Addr().IsPrivate() case sendRecvErrorAlways: return true case sendRecvErrorNever: @@ -167,47 +170,29 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { return nil, errors.New("no firewall rules") } - certificate := c.pki.GetCertState().Certificate - - myVpnAddr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP) - if !ok { - return nil, fmt.Errorf("invalid ip address in certificate: %s", certificate.Details.Ips[0].IP) - } - - myVpnMask, ok := netip.AddrFromSlice(certificate.Details.Ips[0].Mask) - if !ok { - return nil, fmt.Errorf("invalid ip mask in certificate: %s", certificate.Details.Ips[0].Mask) - } - - myVpnAddr = myVpnAddr.Unmap() - myVpnMask = myVpnMask.Unmap() - - if myVpnAddr.BitLen() != myVpnMask.BitLen() { - return nil, fmt.Errorf("ip address and mask are different lengths in certificate") - } - - ones, _ := certificate.Details.Ips[0].Mask.Size() - myVpnNet := netip.PrefixFrom(myVpnAddr, ones) - + cs := c.pki.getCertState() ifce := &Interface{ - pki: c.pki, - hostMap: c.HostMap, - outside: c.Outside, - inside: c.Inside, - cipher: c.Cipher, - firewall: c.Firewall, - serveDns: c.ServeDns, - handshakeManager: c.HandshakeManager, - createTime: time.Now(), - lightHouse: c.lightHouse, - dropLocalBroadcast: c.DropLocalBroadcast, - dropMulticast: c.DropMulticast, - routines: c.routines, - version: c.version, - writers: make([]udp.Conn, c.routines), - readers: make([]io.ReadWriteCloser, c.routines), - myVpnNet: myVpnNet, - relayManager: c.relayManager, + pki: c.pki, + hostMap: c.HostMap, + outside: c.Outside, + inside: c.Inside, + firewall: c.Firewall, + serveDns: c.ServeDns, + handshakeManager: c.HandshakeManager, + createTime: time.Now(), + lightHouse: c.lightHouse, + dropLocalBroadcast: c.DropLocalBroadcast, + dropMulticast: c.DropMulticast, + routines: c.routines, + version: c.version, + writers: make([]udp.Conn, c.routines), + readers: make([]io.ReadWriteCloser, c.routines), + myVpnNetworks: cs.myVpnNetworks, + myVpnNetworksTable: cs.myVpnNetworksTable, + myVpnAddrs: cs.myVpnAddrs, + myVpnAddrsTable: cs.myVpnAddrsTable, + myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable, + relayManager: c.relayManager, conntrackCacheTimeout: c.ConntrackCacheTimeout, @@ -221,12 +206,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { l: c.l, } - if myVpnAddr.Is4() { - addr := myVpnNet.Masked().Addr().As4() - binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(certificate.Details.Ips[0].Mask)) - ifce.myBroadcastAddr = netip.AddrFrom4(addr) - } - ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryWait.Store(int64(c.reQueryWait)) @@ -247,7 +226,7 @@ func (f *Interface) activate() { f.l.WithError(err).Error("Failed to get udp listen address") } - f.l.WithField("interface", f.inside.Name()).WithField("network", f.inside.Cidr().String()). + f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks). WithField("build", f.version).WithField("udpAddr", addr). WithField("boringcrypto", boringEnabled()). Info("Nebula interface is active") @@ -290,16 +269,22 @@ func (f *Interface) listenOut(i int) { runtime.LockOSThread() var li udp.Conn - // TODO clean this up with a coherent interface for each outside connection if i > 0 { li = f.writers[i] } else { li = f.outside } + ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) lhh := f.lightHouse.NewRequestHandler() - conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) - li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i) + plaintext := make([]byte, udp.MTU) + h := &header.H{} + fwPacket := &firewall.Packet{} + nb := make([]byte, 12, 12) + + li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { + f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) + }) } func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { @@ -356,7 +341,7 @@ func (f *Interface) reloadFirewall(c *config.C) { return } - fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c) + fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c) if err != nil { f.l.WithError(err).Error("Error while creating firewall during reload") return @@ -441,6 +426,8 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { var rawStats func() certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil) + certDefaultVersion := metrics.GetOrRegisterGauge("certificate.default_version", nil) + certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil) for { select { @@ -450,17 +437,37 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { f.firewall.EmitStats() f.handshakeManager.EmitStats() udpStats() + + certState := f.pki.getCertState() + defaultCrt := certState.GetDefaultCertificate() + certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second)) + certDefaultVersion.Update(int64(defaultCrt.Version())) + if f.udpRaw != nil { if rawStats == nil { rawStats = udp.NewRawStatsEmitter(f.udpRaw) } rawStats() } - certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.Details.NotAfter.Sub(time.Now()) / time.Second)) + + // Report the max certificate version we are capable of using + if certState.v2Cert != nil { + certMaxVersion.Update(int64(certState.v2Cert.Version())) + } else { + certMaxVersion.Update(int64(certState.v1Cert.Version())) + } } } } +func (f *Interface) GetHostInfo(vpnIp netip.Addr) *HostInfo { + return f.hostMap.QueryVpnAddr(vpnIp) +} + +func (f *Interface) GetCertState() *CertState { + return f.pki.getCertState() +} + func (f *Interface) Close() error { f.closed.Store(true) diff --git a/iputil/packet.go b/iputil/packet.go index 719e034..b18e524 100644 --- a/iputil/packet.go +++ b/iputil/packet.go @@ -6,8 +6,6 @@ import ( "golang.org/x/net/ipv4" ) -//TODO: IPV6-WORK can probably delete this - const ( // Need 96 bytes for the largest reject packet: // - 20 byte ipv4 header diff --git a/lighthouse.go b/lighthouse.go index 62f4065..ce37023 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -7,6 +7,7 @@ import ( "fmt" "net" "net/netip" + "slices" "strconv" "sync" "sync/atomic" @@ -15,28 +16,28 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/util" ) -//TODO: if a lighthouse doesn't have an answer, clients AGGRESSIVELY REQUERY.. why? handshake manager and/or getOrHandshake? -//TODO: nodes are roaming lighthouses, this is bad. How are they learning? - var ErrHostNotKnown = errors.New("host not known") type LightHouse struct { - //TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time + //TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time sync.RWMutex //Because we concurrently read and write to our maps ctx context.Context amLighthouse bool - myVpnNet netip.Prefix - punchConn udp.Conn - punchy *Punchy + + myVpnNetworks []netip.Prefix + myVpnNetworksTable *bart.Table[struct{}] + punchConn udp.Conn + punchy *Punchy // Local cache of answers from light houses - // map of vpn Ip to answers + // map of vpn addr to answers addrMap map[netip.Addr]*RemoteList // filters remote addresses allowed for each host @@ -64,12 +65,12 @@ type LightHouse struct { advertiseAddrs atomic.Pointer[[]netip.AddrPort] - // IP's of relays that can be used by peers to access me + // Addr's of relays that can be used by peers to access me relaysForMe atomic.Pointer[[]netip.Addr] queryChan chan netip.Addr - calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote + calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote metrics *MessageMetrics metricHolepunchTx metrics.Counter @@ -78,7 +79,7 @@ type LightHouse struct { // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object // addrMap should be nil unless this is during a config reload -func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet netip.Prefix, pc udp.Conn, p *Punchy) (*LightHouse, error) { +func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) { amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) nebulaPort := uint32(c.GetInt("listen.port", 0)) if amLighthouse && nebulaPort == 0 { @@ -95,15 +96,16 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, } h := LightHouse{ - ctx: ctx, - amLighthouse: amLighthouse, - myVpnNet: myVpnNet, - addrMap: make(map[netip.Addr]*RemoteList), - nebulaPort: nebulaPort, - punchConn: pc, - punchy: p, - queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), - l: l, + ctx: ctx, + amLighthouse: amLighthouse, + myVpnNetworks: cs.myVpnNetworks, + myVpnNetworksTable: cs.myVpnNetworksTable, + addrMap: make(map[netip.Addr]*RemoteList), + nebulaPort: nebulaPort, + punchConn: pc, + punchy: p, + queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), + l: l, } lighthouses := make(map[netip.Addr]struct{}) h.lighthouses.Store(&lighthouses) @@ -180,11 +182,11 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { return util.NewContextualError("Unable to parse lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err) } - ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", host) + addrs, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", host) if err != nil { return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err) } - if len(ips) == 0 { + if len(addrs) == 0 { return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, nil) } @@ -197,15 +199,16 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { port = int(lh.nebulaPort) } - //TODO: we could technically insert all returned ips instead of just the first one if a dns lookup was used - ip := ips[0].Unmap() - if lh.myVpnNet.Contains(ip) { + //TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used + addr := addrs[0].Unmap() + _, found := lh.myVpnNetworksTable.Lookup(addr) + if found { lh.l.WithField("addr", rawAddr).WithField("entry", i+1). Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range") continue } - advAddrs = append(advAddrs, netip.AddrPortFrom(ip, uint16(port))) + advAddrs = append(advAddrs, netip.AddrPortFrom(addr, uint16(port))) } lh.advertiseAddrs.Store(&advAddrs) @@ -238,7 +241,6 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { lh.remoteAllowList.Store(ral) if !initial { - //TODO: a diff will be annoyingly difficult lh.l.Info("lighthouse.remote_allow_list and/or lighthouse.remote_allow_ranges has changed") } } @@ -251,7 +253,6 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { lh.localAllowList.Store(lal) if !initial { - //TODO: a diff will be annoyingly difficult lh.l.Info("lighthouse.local_allow_list has changed") } } @@ -264,7 +265,6 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { lh.calculatedRemotes.Store(cr) if !initial { - //TODO: a diff will be annoyingly difficult lh.l.Info("lighthouse.calculated_remotes has changed") } } @@ -275,8 +275,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { // Entries no longer present must have their (possible) background DNS goroutines stopped. if existingStaticList := lh.staticList.Load(); existingStaticList != nil { lh.RLock() - for staticVpnIp := range *existingStaticList { - if am, ok := lh.addrMap[staticVpnIp]; ok && am != nil { + for staticVpnAddr := range *existingStaticList { + if am, ok := lh.addrMap[staticVpnAddr]; ok && am != nil { am.hr.Cancel() } } @@ -291,7 +291,6 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { lh.staticList.Store(&staticList) if !initial { - //TODO: we should remove any remote list entries for static hosts that were removed/modified? if c.HasChanged("static_host_map") { lh.l.Info("static_host_map has changed") } @@ -333,11 +332,11 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { case false: relaysForMe := []netip.Addr{} for _, v := range c.GetStringSlice("relay.relays", nil) { - lh.l.WithField("relay", v).Info("Read relay from config") - configRIP, err := netip.ParseAddr(v) - //TODO: We could print the error here - if err == nil { + if err != nil { + lh.l.WithField("relay", v).WithError(err).Warn("Parse relay from config failed") + } else { + lh.l.WithField("relay", v).Info("Read relay from config") relaysForMe = append(relaysForMe, configRIP) } } @@ -355,14 +354,16 @@ func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{ } for i, host := range lhs { - ip, err := netip.ParseAddr(host) + addr, err := netip.ParseAddr(host) if err != nil { return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err) } - if !lh.myVpnNet.Contains(ip) { - return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": lh.myVpnNet}, nil) + + _, found := lh.myVpnNetworksTable.Lookup(addr) + if !found { + return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil) } - lhMap[ip] = struct{}{} + lhMap[addr] = struct{}{} } if !lh.amLighthouse && len(lhMap) == 0 { @@ -370,9 +371,9 @@ func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{ } staticList := lh.GetStaticHostList() - for lhIP, _ := range lhMap { - if _, ok := staticList[lhIP]; !ok { - return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhIP) + for lhAddr, _ := range lhMap { + if _, ok := staticList[lhAddr]; !ok { + return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr) } } @@ -425,13 +426,14 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc i := 0 for k, v := range shm { - vpnIp, err := netip.ParseAddr(fmt.Sprintf("%v", k)) + vpnAddr, err := netip.ParseAddr(fmt.Sprintf("%v", k)) if err != nil { return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err) } - if !lh.myVpnNet.Contains(vpnIp) { - return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": lh.myVpnNet, "entry": i + 1}, nil) + _, found := lh.myVpnNetworksTable.Lookup(vpnAddr) + if !found { + return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil) } vals, ok := v.([]interface{}) @@ -443,7 +445,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v)) } - err = lh.addStaticRemotes(i, d, network, lookupTimeout, vpnIp, remoteAddrs, staticList) + err = lh.addStaticRemotes(i, d, network, lookupTimeout, vpnAddr, remoteAddrs, staticList) if err != nil { return err } @@ -453,12 +455,12 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc return nil } -func (lh *LightHouse) Query(ip netip.Addr) *RemoteList { - if !lh.IsLighthouseIP(ip) { - lh.QueryServer(ip) +func (lh *LightHouse) Query(vpnAddr netip.Addr) *RemoteList { + if !lh.IsLighthouseAddr(vpnAddr) { + lh.QueryServer(vpnAddr) } lh.RLock() - if v, ok := lh.addrMap[ip]; ok { + if v, ok := lh.addrMap[vpnAddr]; ok { lh.RUnlock() return v } @@ -467,18 +469,18 @@ func (lh *LightHouse) Query(ip netip.Addr) *RemoteList { } // QueryServer is asynchronous so no reply should be expected -func (lh *LightHouse) QueryServer(ip netip.Addr) { - // Don't put lighthouse ips in the query channel because we can't query lighthouses about lighthouses - if lh.amLighthouse || lh.IsLighthouseIP(ip) { +func (lh *LightHouse) QueryServer(vpnAddr netip.Addr) { + // Don't put lighthouse addrs in the query channel because we can't query lighthouses about lighthouses + if lh.amLighthouse || lh.IsLighthouseAddr(vpnAddr) { return } - lh.queryChan <- ip + lh.queryChan <- vpnAddr } -func (lh *LightHouse) QueryCache(ip netip.Addr) *RemoteList { +func (lh *LightHouse) QueryCache(vpnAddrs []netip.Addr) *RemoteList { lh.RLock() - if v, ok := lh.addrMap[ip]; ok { + if v, ok := lh.addrMap[vpnAddrs[0]]; ok { lh.RUnlock() return v } @@ -487,24 +489,27 @@ func (lh *LightHouse) QueryCache(ip netip.Addr) *RemoteList { lh.Lock() defer lh.Unlock() // Add an entry if we don't already have one - return lh.unlockedGetRemoteList(ip) + return lh.unlockedGetRemoteList(vpnAddrs) } // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing -// details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp +// details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnAddr // If one is found then f() is called with proper locking, f() must return result of n.MarshalTo() -func (lh *LightHouse) queryAndPrepMessage(vpnIp netip.Addr, f func(*cache) (int, error)) (bool, int, error) { +func (lh *LightHouse) queryAndPrepMessage(vpnAddr netip.Addr, f func(*cache) (int, error)) (bool, int, error) { lh.RLock() // Do we have an entry in the main cache? - if v, ok := lh.addrMap[vpnIp]; ok { + if v, ok := lh.addrMap[vpnAddr]; ok { // Swap lh lock for remote list lock v.RLock() defer v.RUnlock() lh.RUnlock() - // vpnIp should also be the owner here since we are a lighthouse. - c := v.cache[vpnIp] + // We may be asking about a non primary address so lets get the primary address + if slices.Contains(v.vpnAddrs, vpnAddr) { + vpnAddr = v.vpnAddrs[0] + } + c := v.cache[vpnAddr] // Make sure we have if c != nil { n, err := f(c) @@ -516,112 +521,140 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp netip.Addr, f func(*cache) (int, return false, 0, nil } -func (lh *LightHouse) DeleteVpnIp(vpnIp netip.Addr) { +func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) { // First we check the static mapping // and do nothing if it is there - if _, ok := lh.GetStaticHostList()[vpnIp]; ok { + if _, ok := lh.GetStaticHostList()[allVpnAddrs[0]]; ok { return } lh.Lock() - //l.Debugln(lh.addrMap) - delete(lh.addrMap, vpnIp) - - if lh.l.Level >= logrus.DebugLevel { - lh.l.Debugf("deleting %s from lighthouse.", vpnIp) + rm, ok := lh.addrMap[allVpnAddrs[0]] + if ok { + for _, addr := range allVpnAddrs { + srm := lh.addrMap[addr] + if srm == rm { + delete(lh.addrMap, addr) + if lh.l.Level >= logrus.DebugLevel { + lh.l.Debugf("deleting %s from lighthouse.", addr) + } + } + } } - lh.Unlock() } -// AddStaticRemote adds a static host entry for vpnIp as ourselves as the owner +// AddStaticRemote adds a static host entry for vpnAddr as ourselves as the owner // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client // NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it -func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp netip.Addr, toAddrs []string, staticList map[netip.Addr]struct{}) error { +func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnAddr netip.Addr, toAddrs []string, staticList map[netip.Addr]struct{}) error { lh.Lock() - am := lh.unlockedGetRemoteList(vpnIp) + am := lh.unlockedGetRemoteList([]netip.Addr{vpnAddr}) am.Lock() defer am.Unlock() ctx := lh.ctx lh.Unlock() hr, err := NewHostnameResults(ctx, lh.l, d, network, timeout, toAddrs, func() { - // This callback runs whenever the DNS hostname resolver finds a different set of IP's + // This callback runs whenever the DNS hostname resolver finds a different set of addr's // in its resolution for hostnames. am.Lock() defer am.Unlock() am.shouldRebuild = true }) if err != nil { - return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err) + return util.NewContextualError("Static host address could not be parsed", m{"vpnAddr": vpnAddr, "entry": i + 1}, err) } am.unlockedSetHostnamesResults(hr) - for _, addrPort := range hr.GetIPs() { - if !lh.shouldAdd(vpnIp, addrPort.Addr()) { + for _, addrPort := range hr.GetAddrs() { + if !lh.shouldAdd(vpnAddr, addrPort.Addr()) { continue } switch { case addrPort.Addr().Is4(): - am.unlockedPrependV4(lh.myVpnNet.Addr(), NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port())) + am.unlockedPrependV4(lh.myVpnNetworks[0].Addr(), netAddrToProtoV4AddrPort(addrPort.Addr(), addrPort.Port())) case addrPort.Addr().Is6(): - am.unlockedPrependV6(lh.myVpnNet.Addr(), NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port())) + am.unlockedPrependV6(lh.myVpnNetworks[0].Addr(), netAddrToProtoV6AddrPort(addrPort.Addr(), addrPort.Port())) } } // Mark it as static in the caller provided map - staticList[vpnIp] = struct{}{} + staticList[vpnAddr] = struct{}{} return nil } // addCalculatedRemotes adds any calculated remotes based on the // lighthouse.calculated_remotes configuration. It returns true if any // calculated remotes were added -func (lh *LightHouse) addCalculatedRemotes(vpnIp netip.Addr) bool { +func (lh *LightHouse) addCalculatedRemotes(vpnAddr netip.Addr) bool { tree := lh.getCalculatedRemotes() if tree == nil { return false } - calculatedRemotes, ok := tree.Lookup(vpnIp) + calculatedRemotes, ok := tree.Lookup(vpnAddr) if !ok { return false } - var calculated []*Ip4AndPort + var calculatedV4 []*V4AddrPort + var calculatedV6 []*V6AddrPort for _, cr := range calculatedRemotes { - c := cr.Apply(vpnIp) - if c != nil { - calculated = append(calculated, c) + if vpnAddr.Is4() { + c := cr.ApplyV4(vpnAddr) + if c != nil { + calculatedV4 = append(calculatedV4, c) + } + } else if vpnAddr.Is6() { + c := cr.ApplyV6(vpnAddr) + if c != nil { + calculatedV6 = append(calculatedV6, c) + } } } lh.Lock() - am := lh.unlockedGetRemoteList(vpnIp) + am := lh.unlockedGetRemoteList([]netip.Addr{vpnAddr}) am.Lock() defer am.Unlock() lh.Unlock() - am.unlockedSetV4(lh.myVpnNet.Addr(), vpnIp, calculated, lh.unlockedShouldAddV4) + if len(calculatedV4) > 0 { + am.unlockedSetV4(lh.myVpnNetworks[0].Addr(), vpnAddr, calculatedV4, lh.unlockedShouldAddV4) + } - return len(calculated) > 0 + if len(calculatedV6) > 0 { + am.unlockedSetV6(lh.myVpnNetworks[0].Addr(), vpnAddr, calculatedV6, lh.unlockedShouldAddV6) + } + + return len(calculatedV4) > 0 || len(calculatedV6) > 0 } -// unlockedGetRemoteList assumes you have the lh lock -func (lh *LightHouse) unlockedGetRemoteList(vpnIp netip.Addr) *RemoteList { - am, ok := lh.addrMap[vpnIp] +// unlockedGetRemoteList +// assumes you have the lh lock +func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList { + am, ok := lh.addrMap[allAddrs[0]] if !ok { - am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) }) - lh.addrMap[vpnIp] = am + am = NewRemoteList(allAddrs, func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) }) + for _, addr := range allAddrs { + lh.addrMap[addr] = am + } } return am } -func (lh *LightHouse) shouldAdd(vpnIp netip.Addr, to netip.Addr) bool { - allow := lh.GetRemoteAllowList().Allow(vpnIp, to) +func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool { + allow := lh.GetRemoteAllowList().Allow(vpnAddr, to) if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") + lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", to).WithField("allow", allow). + Trace("remoteAllowList.Allow") } - if !allow || lh.myVpnNet.Contains(to) { + if !allow { + return false + } + + _, found := lh.myVpnNetworksTable.Lookup(to) + if found { return false } @@ -629,14 +662,20 @@ func (lh *LightHouse) shouldAdd(vpnIp netip.Addr, to netip.Addr) bool { } // unlockedShouldAddV4 checks if to is allowed by our allow list -func (lh *LightHouse) unlockedShouldAddV4(vpnIp netip.Addr, to *Ip4AndPort) bool { - ip := AddrPortFromIp4AndPort(to) - allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr()) +func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bool { + udpAddr := protoV4AddrPortToNetAddrPort(to) + allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr()) if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") + lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow). + Trace("remoteAllowList.Allow") } - if !allow || lh.myVpnNet.Contains(ip.Addr()) { + if !allow { + return false + } + + _, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr()) + if found { return false } @@ -644,78 +683,43 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnIp netip.Addr, to *Ip4AndPort) bool } // unlockedShouldAddV6 checks if to is allowed by our allow list -func (lh *LightHouse) unlockedShouldAddV6(vpnIp netip.Addr, to *Ip6AndPort) bool { - ip := AddrPortFromIp6AndPort(to) - allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr()) +func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bool { + udpAddr := protoV6AddrPortToNetAddrPort(to) + allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr()) if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow") + lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow). + Trace("remoteAllowList.Allow") } - if !allow || lh.myVpnNet.Contains(ip.Addr()) { + if !allow { + return false + } + + _, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr()) + if found { return false } return true } -func lhIp6ToIp(v *Ip6AndPort) net.IP { - ip := make(net.IP, 16) - binary.BigEndian.PutUint64(ip[:8], v.Hi) - binary.BigEndian.PutUint64(ip[8:], v.Lo) - return ip -} - -func (lh *LightHouse) IsLighthouseIP(vpnIp netip.Addr) bool { - if _, ok := lh.GetLighthouses()[vpnIp]; ok { +func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool { + if _, ok := lh.GetLighthouses()[vpnAddr]; ok { return true } return false } -func NewLhQueryByInt(vpnIp netip.Addr) *NebulaMeta { - if vpnIp.Is6() { - //TODO: need to support ipv6 - panic("ipv6 is not yet supported") - } - - b := vpnIp.As4() - return &NebulaMeta{ - Type: NebulaMeta_HostQuery, - Details: &NebulaMetaDetails{ - VpnIp: binary.BigEndian.Uint32(b[:]), - }, - } -} - -func AddrPortFromIp4AndPort(ip *Ip4AndPort) netip.AddrPort { - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], ip.Ip) - return netip.AddrPortFrom(netip.AddrFrom4(b), uint16(ip.Port)) -} - -func AddrPortFromIp6AndPort(ip *Ip6AndPort) netip.AddrPort { - b := [16]byte{} - binary.BigEndian.PutUint64(b[:8], ip.Hi) - binary.BigEndian.PutUint64(b[8:], ip.Lo) - return netip.AddrPortFrom(netip.AddrFrom16(b), uint16(ip.Port)) -} - -func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort { - v4Addr := ip.As4() - return &Ip4AndPort{ - Ip: binary.BigEndian.Uint32(v4Addr[:]), - Port: uint32(port), - } -} - -// TODO: IPV6-WORK we can delete some more of these -func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort { - ip6Addr := ip.As16() - return &Ip6AndPort{ - Hi: binary.BigEndian.Uint64(ip6Addr[:8]), - Lo: binary.BigEndian.Uint64(ip6Addr[8:]), - Port: uint32(port), +// TODO: CERT-V2 IsLighthouseAddr should be sufficient, we just need to update the vpnAddrs for lighthouses after a handshake +// so that we know all the lighthouse vpnAddrs, not just the ones we were configured to talk to initially +func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddr []netip.Addr) bool { + l := lh.GetLighthouses() + for _, a := range vpnAddr { + if _, ok := l[a]; ok { + return true + } } + return false } func (lh *LightHouse) startQueryWorker() { @@ -731,31 +735,85 @@ func (lh *LightHouse) startQueryWorker() { select { case <-lh.ctx.Done(): return - case ip := <-lh.queryChan: - lh.innerQueryServer(ip, nb, out) + case addr := <-lh.queryChan: + lh.innerQueryServer(addr, nb, out) } } }() } -func (lh *LightHouse) innerQueryServer(ip netip.Addr, nb, out []byte) { - if lh.IsLighthouseIP(ip) { +func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { + if lh.IsLighthouseAddr(addr) { return } - // Send a query to the lighthouses and hope for the best next time - query, err := NewLhQueryByInt(ip).Marshal() - if err != nil { - lh.l.WithError(err).WithField("vpnIp", ip).Error("Failed to marshal lighthouse query payload") - return + msg := &NebulaMeta{ + Type: NebulaMeta_HostQuery, + Details: &NebulaMetaDetails{}, } + var v1Query, v2Query []byte + var err error + var v cert.Version + queried := 0 lighthouses := lh.GetLighthouses() - lh.metricTx(NebulaMeta_HostQuery, int64(len(lighthouses))) - for n := range lighthouses { - lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out) + for lhVpnAddr := range lighthouses { + hi := lh.ifce.GetHostInfo(lhVpnAddr) + if hi != nil { + v = hi.ConnectionState.myCert.Version() + } else { + v = lh.ifce.GetCertState().defaultVersion + } + + if v == cert.Version1 { + if !addr.Is4() { + lh.l.WithField("queryVpnAddr", addr).WithField("lighthouseAddr", lhVpnAddr). + Error("Can't query lighthouse for v6 address using a v1 protocol") + continue + } + + if v1Query == nil { + b := addr.As4() + msg.Details.VpnAddr = nil + msg.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) + + v1Query, err = msg.Marshal() + if err != nil { + lh.l.WithError(err).WithField("queryVpnAddr", addr). + WithField("lighthouseAddr", lhVpnAddr). + Error("Failed to marshal lighthouse v1 query payload") + continue + } + } + + lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v1Query, nb, out) + queried++ + + } else if v == cert.Version2 { + if v2Query == nil { + msg.Details.OldVpnAddr = 0 + msg.Details.VpnAddr = netAddrToProtoAddr(addr) + + v2Query, err = msg.Marshal() + if err != nil { + lh.l.WithError(err).WithField("queryVpnAddr", addr). + WithField("lighthouseAddr", lhVpnAddr). + Error("Failed to marshal lighthouse v2 query payload") + continue + } + } + + lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v2Query, nb, out) + queried++ + + } else { + lh.l.Debugf("Can not query lighthouse for %v using unknown protocol version: %v", addr, v) + continue + } } + + lh.metricTx(NebulaMeta_HostQuery, int64(queried)) } func (lh *LightHouse) StartUpdateWorker() { @@ -785,65 +843,120 @@ func (lh *LightHouse) StartUpdateWorker() { } func (lh *LightHouse) SendUpdate() { - var v4 []*Ip4AndPort - var v6 []*Ip6AndPort + var v4 []*V4AddrPort + var v6 []*V6AddrPort for _, e := range lh.GetAdvertiseAddrs() { if e.Addr().Is4() { - v4 = append(v4, NewIp4AndPortFromNetIP(e.Addr(), e.Port())) + v4 = append(v4, netAddrToProtoV4AddrPort(e.Addr(), e.Port())) } else { - v6 = append(v6, NewIp6AndPortFromNetIP(e.Addr(), e.Port())) + v6 = append(v6, netAddrToProtoV6AddrPort(e.Addr(), e.Port())) } } lal := lh.GetLocalAllowList() - for _, e := range localIps(lh.l, lal) { - if lh.myVpnNet.Contains(e) { + for _, e := range localAddrs(lh.l, lal) { + _, found := lh.myVpnNetworksTable.Lookup(e) + if found { continue } - // Only add IPs that aren't my VPN/tun IP + // Only add addrs that aren't my VPN/tun networks if e.Is4() { - v4 = append(v4, NewIp4AndPortFromNetIP(e, uint16(lh.nebulaPort))) + v4 = append(v4, netAddrToProtoV4AddrPort(e, uint16(lh.nebulaPort))) } else { - v6 = append(v6, NewIp6AndPortFromNetIP(e, uint16(lh.nebulaPort))) + v6 = append(v6, netAddrToProtoV6AddrPort(e, uint16(lh.nebulaPort))) } } - var relays []uint32 - for _, r := range lh.GetRelaysForMe() { - //TODO: IPV6-WORK both relays and vpnip need ipv6 support - b := r.As4() - relays = append(relays, binary.BigEndian.Uint32(b[:])) - } - - //TODO: IPV6-WORK both relays and vpnip need ipv6 support - b := lh.myVpnNet.Addr().As4() - - m := &NebulaMeta{ - Type: NebulaMeta_HostUpdateNotification, - Details: &NebulaMetaDetails{ - VpnIp: binary.BigEndian.Uint32(b[:]), - Ip4AndPorts: v4, - Ip6AndPorts: v6, - RelayVpnIp: relays, - }, - } - - lighthouses := lh.GetLighthouses() - lh.metricTx(NebulaMeta_HostUpdateNotification, int64(len(lighthouses))) nb := make([]byte, 12, 12) out := make([]byte, mtu) - mm, err := m.Marshal() - if err != nil { - lh.l.WithError(err).Error("Error while marshaling for lighthouse update") - return + var v1Update, v2Update []byte + var err error + updated := 0 + lighthouses := lh.GetLighthouses() + + for lhVpnAddr := range lighthouses { + var v cert.Version + hi := lh.ifce.GetHostInfo(lhVpnAddr) + if hi != nil { + v = hi.ConnectionState.myCert.Version() + } else { + v = lh.ifce.GetCertState().defaultVersion + } + if v == cert.Version1 { + if v1Update == nil { + if !lh.myVpnNetworks[0].Addr().Is4() { + lh.l.WithField("lighthouseAddr", lhVpnAddr). + Warn("cannot update lighthouse using v1 protocol without an IPv4 address") + continue + } + var relays []uint32 + for _, r := range lh.GetRelaysForMe() { + if !r.Is4() { + continue + } + b := r.As4() + relays = append(relays, binary.BigEndian.Uint32(b[:])) + } + b := lh.myVpnNetworks[0].Addr().As4() + msg := NebulaMeta{ + Type: NebulaMeta_HostUpdateNotification, + Details: &NebulaMetaDetails{ + V4AddrPorts: v4, + V6AddrPorts: v6, + OldRelayVpnAddrs: relays, + OldVpnAddr: binary.BigEndian.Uint32(b[:]), + }, + } + + v1Update, err = msg.Marshal() + if err != nil { + lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr). + Error("Error while marshaling for lighthouse v1 update") + continue + } + } + + lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v1Update, nb, out) + updated++ + + } else if v == cert.Version2 { + if v2Update == nil { + var relays []*Addr + for _, r := range lh.GetRelaysForMe() { + relays = append(relays, netAddrToProtoAddr(r)) + } + + msg := NebulaMeta{ + Type: NebulaMeta_HostUpdateNotification, + Details: &NebulaMetaDetails{ + V4AddrPorts: v4, + V6AddrPorts: v6, + RelayVpnAddrs: relays, + VpnAddr: netAddrToProtoAddr(lh.myVpnNetworks[0].Addr()), + }, + } + + v2Update, err = msg.Marshal() + if err != nil { + lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr). + Error("Error while marshaling for lighthouse v2 update") + continue + } + } + + lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v2Update, nb, out) + updated++ + + } else { + lh.l.Debugf("Can not update lighthouse using unknown protocol version: %v", v) + continue + } } - for vpnIp := range lighthouses { - lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, mm, nb, out) - } + lh.metricTx(NebulaMeta_HostUpdateNotification, int64(updated)) } type LightHouseHandler struct { @@ -886,34 +999,29 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta { lhh.meta.Reset() // Keep the array memory around - details.Ip4AndPorts = details.Ip4AndPorts[:0] - details.Ip6AndPorts = details.Ip6AndPorts[:0] - details.RelayVpnIp = details.RelayVpnIp[:0] + details.V4AddrPorts = details.V4AddrPorts[:0] + details.V6AddrPorts = details.V6AddrPorts[:0] + details.RelayVpnAddrs = details.RelayVpnAddrs[:0] + details.OldRelayVpnAddrs = details.OldRelayVpnAddrs[:0] + details.OldVpnAddr = 0 + details.VpnAddr = nil lhh.meta.Details = details return lhh.meta } -func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc { - return func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte) { - lhh.HandleRequest(rAddr, vpnIp, p, f) - } -} - -func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte, w EncWriter) { +func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs []netip.Addr, p []byte, w EncWriter) { n := lhh.resetMeta() err := n.Unmarshal(p) if err != nil { - lhh.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", rAddr). + lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). Error("Failed to unmarshal lighthouse packet") - //TODO: send recv_error? return } if n.Details == nil { - lhh.l.WithField("vpnIp", vpnIp).WithField("udpAddr", rAddr). + lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). Error("Invalid lighthouse update") - //TODO: send recv_error? return } @@ -921,24 +1029,24 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, vpnIp netip.Ad switch n.Type { case NebulaMeta_HostQuery: - lhh.handleHostQuery(n, vpnIp, rAddr, w) + lhh.handleHostQuery(n, fromVpnAddrs, rAddr, w) case NebulaMeta_HostQueryReply: - lhh.handleHostQueryReply(n, vpnIp) + lhh.handleHostQueryReply(n, fromVpnAddrs) case NebulaMeta_HostUpdateNotification: - lhh.handleHostUpdateNotification(n, vpnIp, w) + lhh.handleHostUpdateNotification(n, fromVpnAddrs, w) case NebulaMeta_HostMovedNotification: case NebulaMeta_HostPunchNotification: - lhh.handleHostPunchNotification(n, vpnIp, w) + lhh.handleHostPunchNotification(n, fromVpnAddrs, w) case NebulaMeta_HostUpdateNotificationAck: // noop } } -func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, addr netip.AddrPort, w EncWriter) { +func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) { // Exit if we don't answer queries if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { @@ -947,21 +1055,37 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, a return } - //TODO: we can DRY this further - reqVpnIp := n.Details.VpnIp + useVersion := cert.Version1 + var queryVpnAddr netip.Addr + if n.Details.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) + queryVpnAddr = netip.AddrFrom4(b) + useVersion = 1 + } else if n.Details.VpnAddr != nil { + queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) + useVersion = 2 + } else { + if lhh.l.Level >= logrus.DebugLevel { + lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).Debugln("Dropping malformed HostQuery") + } + return + } - //TODO: IPV6-WORK - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) - queryVpnIp := netip.AddrFrom4(b) - - //TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data - found, ln, err := lhh.lh.queryAndPrepMessage(queryVpnIp, func(c *cache) (int, error) { + found, ln, err := lhh.lh.queryAndPrepMessage(queryVpnAddr, func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostQueryReply - n.Details.VpnIp = reqVpnIp + if useVersion == cert.Version1 { + if !queryVpnAddr.Is4() { + return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery") + } + b := queryVpnAddr.As4() + n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) + } else { + n.Details.VpnAddr = netAddrToProtoAddr(queryVpnAddr) + } - lhh.coalesceAnswers(c, n) + lhh.coalesceAnswers(useVersion, c, n) return n.MarshalTo(lhh.pb) }) @@ -971,21 +1095,51 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, a } if err != nil { - lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host query reply") + lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host query reply") return } lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1) - w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) - // This signals the other side to punch some zero byte udp packets - found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) { + lhh.sendHostPunchNotification(n, fromVpnAddrs, queryVpnAddr, w) +} + +// sendHostPunchNotification signals the other side to punch some zero byte udp packets +func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, punchNotifDest netip.Addr, w EncWriter) { + whereToPunch := fromVpnAddrs[0] + found, ln, err := lhh.lh.queryAndPrepMessage(whereToPunch, func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostPunchNotification - //TODO: IPV6-WORK - b = vpnIp.As4() - n.Details.VpnIp = binary.BigEndian.Uint32(b[:]) - lhh.coalesceAnswers(c, n) + targetHI := lhh.lh.ifce.GetHostInfo(punchNotifDest) + var useVersion cert.Version + if targetHI == nil { + useVersion = lhh.lh.ifce.GetCertState().defaultVersion + } else { + crt := targetHI.GetCert().Certificate + useVersion = crt.Version() + // we can only retarget if we have a hostinfo + newDest, ok := findNetworkUnion(crt.Networks(), fromVpnAddrs) + if ok { + whereToPunch = newDest + } else { + //TODO: CERT-V2 this means the destination will have no addresses in common with the punch-ee + //choosing to do nothing for now, but maybe we return an error? + } + } + + if useVersion == cert.Version1 { + if !whereToPunch.Is4() { + return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery") + } + b := whereToPunch.As4() + n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) + } else if useVersion == cert.Version2 { + n.Details.VpnAddr = netAddrToProtoAddr(whereToPunch) + } else { + return 0, errors.New("unsupported version") + } + lhh.coalesceAnswers(useVersion, c, n) return n.MarshalTo(lhh.pb) }) @@ -995,139 +1149,169 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, a } if err != nil { - lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host was queried for") + lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host was queried for") return } lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1) - - //TODO: IPV6-WORK - binary.BigEndian.PutUint32(b[:], reqVpnIp) - sendTo := netip.AddrFrom4(b) - w.SendMessageToVpnIp(header.LightHouse, 0, sendTo, lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnAddr(header.LightHouse, 0, punchNotifDest, lhh.pb[:ln], lhh.nb, lhh.out[:0]) } -func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) { +func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *NebulaMeta) { if c.v4 != nil { if c.v4.learned != nil { - n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, c.v4.learned) + n.Details.V4AddrPorts = append(n.Details.V4AddrPorts, c.v4.learned) } if c.v4.reported != nil && len(c.v4.reported) > 0 { - n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, c.v4.reported...) + n.Details.V4AddrPorts = append(n.Details.V4AddrPorts, c.v4.reported...) } } if c.v6 != nil { if c.v6.learned != nil { - n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, c.v6.learned) + n.Details.V6AddrPorts = append(n.Details.V6AddrPorts, c.v6.learned) } if c.v6.reported != nil && len(c.v6.reported) > 0 { - n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, c.v6.reported...) + n.Details.V6AddrPorts = append(n.Details.V6AddrPorts, c.v6.reported...) } } if c.relay != nil { - //TODO: IPV6-WORK - relays := make([]uint32, len(c.relay.relay)) - b := [4]byte{} - for i, _ := range relays { - b = c.relay.relay[i].As4() - relays[i] = binary.BigEndian.Uint32(b[:]) + if v == cert.Version1 { + b := [4]byte{} + for _, r := range c.relay.relay { + if !r.Is4() { + continue + } + + b = r.As4() + n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:])) + } + + } else if v == cert.Version2 { + for _, r := range c.relay.relay { + n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r)) + } + + } else { + //TODO: CERT-V2 don't panic + panic("unsupported version") } - n.Details.RelayVpnIp = append(n.Details.RelayVpnIp, relays...) } } -func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp netip.Addr) { - if !lhh.lh.IsLighthouseIP(vpnIp) { +func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs []netip.Addr) { + if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) { return } lhh.lh.Lock() - //TODO: IPV6-WORK - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) - certVpnIp := netip.AddrFrom4(b) - am := lhh.lh.unlockedGetRemoteList(certVpnIp) + + var certVpnAddr netip.Addr + if n.Details.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) + certVpnAddr = netip.AddrFrom4(b) + } else if n.Details.VpnAddr != nil { + certVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) + } + relays := n.Details.GetRelays() + + am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnAddr}) am.Lock() lhh.lh.Unlock() - //TODO: IPV6-WORK - am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) - - //TODO: IPV6-WORK - relays := make([]netip.Addr, len(n.Details.RelayVpnIp)) - for i, _ := range n.Details.RelayVpnIp { - binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i]) - relays[i] = netip.AddrFrom4(b) - } - am.unlockedSetRelay(vpnIp, certVpnIp, relays) + am.unlockedSetV4(fromVpnAddrs[0], certVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(fromVpnAddrs[0], certVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) + am.unlockedSetRelay(fromVpnAddrs[0], relays) am.Unlock() // Non-blocking attempt to trigger, skip if it would block select { - case lhh.lh.handshakeTrigger <- certVpnIp: + case lhh.lh.handshakeTrigger <- certVpnAddr: default: } } -func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) { +func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp) + lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs) } return } - //Simple check that the host sent this not someone else - //TODO: IPV6-WORK - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) - detailsVpnIp := netip.AddrFrom4(b) - if detailsVpnIp != vpnIp { + var detailsVpnAddr netip.Addr + useVersion := cert.Version1 + if n.Details.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) + detailsVpnAddr = netip.AddrFrom4(b) + useVersion = cert.Version1 + } else if n.Details.VpnAddr != nil { + detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) + useVersion = cert.Version2 + } else { if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("vpnIp", vpnIp).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update") + lhh.l.WithField("details", n.Details).Debugf("dropping invalid HostUpdateNotification") } return } + //TODO: CERT-V2 hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4? + //TODO: CERT-V2 why do we care about the vpnAddr in the packet? We know where it came from, right? + //Simple check that the host sent this not someone else + if !slices.Contains(fromVpnAddrs, detailsVpnAddr) { + if lhh.l.Level >= logrus.DebugLevel { + lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update") + } + return + } + + relays := n.Details.GetRelays() + lhh.lh.Lock() - am := lhh.lh.unlockedGetRemoteList(vpnIp) + am := lhh.lh.unlockedGetRemoteList(fromVpnAddrs) am.Lock() lhh.lh.Unlock() - am.unlockedSetV4(vpnIp, detailsVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(vpnIp, detailsVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) - - //TODO: IPV6-WORK - relays := make([]netip.Addr, len(n.Details.RelayVpnIp)) - for i, _ := range n.Details.RelayVpnIp { - binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i]) - relays[i] = netip.AddrFrom4(b) - } - am.unlockedSetRelay(vpnIp, detailsVpnIp, relays) + am.unlockedSetV4(fromVpnAddrs[0], detailsVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(fromVpnAddrs[0], detailsVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) + am.unlockedSetRelay(fromVpnAddrs[0], relays) am.Unlock() n = lhh.resetMeta() n.Type = NebulaMeta_HostUpdateNotificationAck - //TODO: IPV6-WORK - vpnIpB := vpnIp.As4() - n.Details.VpnIp = binary.BigEndian.Uint32(vpnIpB[:]) - ln, err := n.MarshalTo(lhh.pb) + if useVersion == cert.Version1 { + if !fromVpnAddrs[0].Is4() { + lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message") + return + } + vpnAddrB := fromVpnAddrs[0].As4() + n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:]) + } else if useVersion == cert.Version2 { + n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0]) + } else { + lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version") + return + } + ln, err := n.MarshalTo(lhh.pb) if err != nil { - lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host update ack") + lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host update ack") return } lhh.lh.metricTx(NebulaMeta_HostUpdateNotificationAck, 1) - w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) } -func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) { - if !lhh.lh.IsLighthouseIP(vpnIp) { +func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { + //It's possible the lighthouse is communicating with us using a non primary vpn addr, + //which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs. + //maybe one day we'll have a better idea, if it matters. + if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) { return } @@ -1144,39 +1328,123 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp n }() if lhh.l.Level >= logrus.DebugLevel { - //TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp)) - //TODO: IPV6-WORK, make this debug line not suck - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) - lhh.l.Debugf("Punching on %d for %v", vpnPeer.Port(), netip.AddrFrom4(b)) + var logVpnAddr netip.Addr + if n.Details.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) + logVpnAddr = netip.AddrFrom4(b) + } else if n.Details.VpnAddr != nil { + logVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) + } + lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr) } } - for _, a := range n.Details.Ip4AndPorts { - punch(AddrPortFromIp4AndPort(a)) + for _, a := range n.Details.V4AddrPorts { + punch(protoV4AddrPortToNetAddrPort(a)) } - for _, a := range n.Details.Ip6AndPorts { - punch(AddrPortFromIp6AndPort(a)) + for _, a := range n.Details.V6AddrPorts { + punch(protoV6AddrPortToNetAddrPort(a)) } // This sends a nebula test packet to the host trying to contact us. In the case // of a double nat or other difficult scenario, this may help establish // a tunnel. if lhh.lh.punchy.GetRespond() { - //TODO: IPV6-WORK - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) - queryVpnIp := netip.AddrFrom4(b) + var queryVpnAddr netip.Addr + if n.Details.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) + queryVpnAddr = netip.AddrFrom4(b) + } else if n.Details.VpnAddr != nil { + queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) + } + go func() { time.Sleep(lhh.lh.punchy.GetRespondDelay()) if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugf("Sending a nebula test packet to vpn ip %s", queryVpnIp) + lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", queryVpnAddr) } //NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine // for each punchBack packet. We should move this into a timerwheel or a single goroutine // managed by a channel. - w.SendMessageToVpnIp(header.Test, header.TestRequest, queryVpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) }() } } + +func protoAddrToNetAddr(addr *Addr) netip.Addr { + b := [16]byte{} + binary.BigEndian.PutUint64(b[:8], addr.Hi) + binary.BigEndian.PutUint64(b[8:], addr.Lo) + return netip.AddrFrom16(b).Unmap() +} + +func protoV4AddrPortToNetAddrPort(ap *V4AddrPort) netip.AddrPort { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], ap.Addr) + return netip.AddrPortFrom(netip.AddrFrom4(b), uint16(ap.Port)) +} + +func protoV6AddrPortToNetAddrPort(ap *V6AddrPort) netip.AddrPort { + b := [16]byte{} + binary.BigEndian.PutUint64(b[:8], ap.Hi) + binary.BigEndian.PutUint64(b[8:], ap.Lo) + return netip.AddrPortFrom(netip.AddrFrom16(b), uint16(ap.Port)) +} + +func netAddrToProtoAddr(addr netip.Addr) *Addr { + b := addr.As16() + return &Addr{ + Hi: binary.BigEndian.Uint64(b[:8]), + Lo: binary.BigEndian.Uint64(b[8:]), + } +} + +func netAddrToProtoV4AddrPort(addr netip.Addr, port uint16) *V4AddrPort { + v4Addr := addr.As4() + return &V4AddrPort{ + Addr: binary.BigEndian.Uint32(v4Addr[:]), + Port: uint32(port), + } +} + +func netAddrToProtoV6AddrPort(addr netip.Addr, port uint16) *V6AddrPort { + v6Addr := addr.As16() + return &V6AddrPort{ + Hi: binary.BigEndian.Uint64(v6Addr[:8]), + Lo: binary.BigEndian.Uint64(v6Addr[8:]), + Port: uint32(port), + } +} + +func (d *NebulaMetaDetails) GetRelays() []netip.Addr { + var relays []netip.Addr + if len(d.OldRelayVpnAddrs) > 0 { + b := [4]byte{} + for _, r := range d.OldRelayVpnAddrs { + binary.BigEndian.PutUint32(b[:], r) + relays = append(relays, netip.AddrFrom4(b)) + } + } + + if len(d.RelayVpnAddrs) > 0 { + for _, r := range d.RelayVpnAddrs { + relays = append(relays, protoAddrToNetAddr(r)) + } + } + return relays +} + +// FindNetworkUnion returns the first netip.Addr contained in the list of provided netip.Prefix, if able +func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr, bool) { + for i := range prefixes { + for j := range addrs { + if prefixes[i].Contains(addrs[j]) { + return addrs[j], true + } + } + } + return netip.Addr{}, false +} diff --git a/lighthouse_test.go b/lighthouse_test.go index 2599f5f..d5947aa 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -7,6 +7,8 @@ import ( "net/netip" "testing" + "github.com/gaissmai/bart" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/test" @@ -14,62 +16,51 @@ import ( "gopkg.in/yaml.v2" ) -//TODO: Add a test to ensure udpAddr is copied and not reused - func TestOldIPv4Only(t *testing.T) { // This test ensures our new ipv6 enabled LH protobuf IpAndPorts works with the old style to enable backwards compatibility b := []byte{8, 129, 130, 132, 80, 16, 10} - var m Ip4AndPort + var m V4AddrPort err := m.Unmarshal(b) assert.NoError(t, err) ip := netip.MustParseAddr("10.1.1.1") bp := ip.As4() - assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetIp()) -} - -func TestNewLhQuery(t *testing.T) { - myIp, err := netip.ParseAddr("192.1.1.1") - assert.NoError(t, err) - - // Generating a new lh query should work - a := NewLhQueryByInt(myIp) - - // The result should be a nebulameta protobuf - assert.IsType(t, &NebulaMeta{}, a) - - // It should also Marshal fine - b, err := a.Marshal() - assert.Nil(t, err) - - // and then Unmarshal fine - n := &NebulaMeta{} - err = n.Unmarshal(b) - assert.Nil(t, err) - + assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr()) } func Test_lhStaticMapping(t *testing.T) { l := test.NewLogger() myVpnNet := netip.MustParsePrefix("10.128.0.1/16") + nt := new(bart.Table[struct{}]) + nt.Insert(myVpnNet, struct{}{}) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } lh1 := "10.128.0.2" c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} - _, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) + _, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) assert.Nil(t, err) lh2 := "10.128.0.3" c = config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}} - _, err = NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) + _, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") } func TestReloadLighthouseInterval(t *testing.T) { l := test.NewLogger() myVpnNet := netip.MustParsePrefix("10.128.0.1/16") + nt := new(bart.Table[struct{}]) + nt.Insert(myVpnNet, struct{}{}) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } lh1 := "10.128.0.2" c := config.NewC(l) @@ -79,7 +70,7 @@ func TestReloadLighthouseInterval(t *testing.T) { } c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) assert.NoError(t, err) lh.ifce = &mockEncWriter{} @@ -99,9 +90,15 @@ func TestReloadLighthouseInterval(t *testing.T) { func BenchmarkLighthouseHandleRequest(b *testing.B) { l := test.NewLogger() myVpnNet := netip.MustParsePrefix("10.128.0.1/0") + nt := new(bart.Table[struct{}]) + nt.Insert(myVpnNet, struct{}{}) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } c := config.NewC(l) - lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) if !assert.NoError(b, err) { b.Fatal() } @@ -110,46 +107,47 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346") vpnIp3 := netip.MustParseAddr("0.0.0.3") - lh.addrMap[vpnIp3] = NewRemoteList(nil) + lh.addrMap[vpnIp3] = NewRemoteList([]netip.Addr{vpnIp3}, nil) lh.addrMap[vpnIp3].unlockedSetV4( vpnIp3, vpnIp3, - []*Ip4AndPort{ - NewIp4AndPortFromNetIP(hAddr.Addr(), hAddr.Port()), - NewIp4AndPortFromNetIP(hAddr2.Addr(), hAddr2.Port()), + []*V4AddrPort{ + netAddrToProtoV4AddrPort(hAddr.Addr(), hAddr.Port()), + netAddrToProtoV4AddrPort(hAddr2.Addr(), hAddr2.Port()), }, - func(netip.Addr, *Ip4AndPort) bool { return true }, + func(netip.Addr, *V4AddrPort) bool { return true }, ) rAddr := netip.MustParseAddrPort("1.2.2.3:12345") rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346") vpnIp2 := netip.MustParseAddr("0.0.0.3") - lh.addrMap[vpnIp2] = NewRemoteList(nil) + lh.addrMap[vpnIp2] = NewRemoteList([]netip.Addr{vpnIp2}, nil) lh.addrMap[vpnIp2].unlockedSetV4( vpnIp3, vpnIp3, - []*Ip4AndPort{ - NewIp4AndPortFromNetIP(rAddr.Addr(), rAddr.Port()), - NewIp4AndPortFromNetIP(rAddr2.Addr(), rAddr2.Port()), + []*V4AddrPort{ + netAddrToProtoV4AddrPort(rAddr.Addr(), rAddr.Port()), + netAddrToProtoV4AddrPort(rAddr2.Addr(), rAddr2.Port()), }, - func(netip.Addr, *Ip4AndPort) bool { return true }, + func(netip.Addr, *V4AddrPort) bool { return true }, ) mw := &mockEncWriter{} + hi := []netip.Addr{vpnIp2} b.Run("notfound", func(b *testing.B) { lhh := lh.NewRequestHandler() req := &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{ - VpnIp: 4, - Ip4AndPorts: nil, + OldVpnAddr: 4, + V4AddrPorts: nil, }, } p, err := req.Marshal() assert.NoError(b, err) for n := 0; n < b.N; n++ { - lhh.HandleRequest(rAddr, vpnIp2, p, mw) + lhh.HandleRequest(rAddr, hi, p, mw) } }) b.Run("found", func(b *testing.B) { @@ -157,15 +155,15 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { req := &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{ - VpnIp: 3, - Ip4AndPorts: nil, + OldVpnAddr: 3, + V4AddrPorts: nil, }, } p, err := req.Marshal() assert.NoError(b, err) for n := 0; n < b.N; n++ { - lhh.HandleRequest(rAddr, vpnIp2, p, mw) + lhh.HandleRequest(rAddr, hi, p, mw) } }) } @@ -197,40 +195,49 @@ func TestLighthouse_Memory(t *testing.T) { c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil) + + myVpnNet := netip.MustParsePrefix("10.128.0.1/24") + nt := new(bart.Table[struct{}]) + nt.Insert(myVpnNet, struct{}{}) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + lh.ifce = &mockEncWriter{} assert.NoError(t, err) lhh := lh.NewRequestHandler() // Test that my first update responds with just that newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh) r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr2) // Ensure we don't accumulate addresses newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr3) // Grow it back to 2 newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4) // Update a different host and ask about it newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh) r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) // Have both hosts ask about the other r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4) r = newLHHostRequest(myUdpAddr0, myVpnIp, theirVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) // Make sure we didn't get changed r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4) // Ensure proper ordering and limiting // Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe) @@ -255,7 +262,7 @@ func TestLighthouse_Memory(t *testing.T) { r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray( t, - r.msg.Details.Ip4AndPorts, + r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9, ) @@ -265,7 +272,7 @@ func TestLighthouse_Memory(t *testing.T) { good := netip.MustParseAddrPort("1.128.0.99:4242") newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, good) } func TestLighthouse_reload(t *testing.T) { @@ -273,7 +280,16 @@ func TestLighthouse_reload(t *testing.T) { c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil) + + myVpnNet := netip.MustParsePrefix("10.128.0.1/24") + nt := new(bart.Table[struct{}]) + nt.Insert(myVpnNet, struct{}{}) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } + + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) assert.NoError(t, err) nc := map[interface{}]interface{}{ @@ -290,13 +306,16 @@ func TestLighthouse_reload(t *testing.T) { } func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply { - //TODO: IPV6-WORK - bip := queryVpnIp.As4() req := &NebulaMeta{ - Type: NebulaMeta_HostQuery, - Details: &NebulaMetaDetails{ - VpnIp: binary.BigEndian.Uint32(bip[:]), - }, + Type: NebulaMeta_HostQuery, + Details: &NebulaMetaDetails{}, + } + + if queryVpnIp.Is4() { + bip := queryVpnIp.As4() + req.Details.OldVpnAddr = binary.BigEndian.Uint32(bip[:]) + } else { + req.Details.VpnAddr = netAddrToProtoAddr(queryVpnIp) } b, err := req.Marshal() @@ -308,23 +327,29 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l w := &testEncWriter{ metaFilter: &filter, } - lhh.HandleRequest(fromAddr, myVpnIp, b, w) + lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w) return w.lastReply } func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.AddrPort, lhh *LightHouseHandler) { - //TODO: IPV6-WORK - bip := vpnIp.As4() req := &NebulaMeta{ - Type: NebulaMeta_HostUpdateNotification, - Details: &NebulaMetaDetails{ - VpnIp: binary.BigEndian.Uint32(bip[:]), - Ip4AndPorts: make([]*Ip4AndPort, len(addrs)), - }, + Type: NebulaMeta_HostUpdateNotification, + Details: &NebulaMetaDetails{}, } - for k, v := range addrs { - req.Details.Ip4AndPorts[k] = NewIp4AndPortFromNetIP(v.Addr(), v.Port()) + if vpnIp.Is4() { + bip := vpnIp.As4() + req.Details.OldVpnAddr = binary.BigEndian.Uint32(bip[:]) + } else { + req.Details.VpnAddr = netAddrToProtoAddr(vpnIp) + } + + for _, v := range addrs { + if v.Addr().Is4() { + req.Details.V4AddrPorts = append(req.Details.V4AddrPorts, netAddrToProtoV4AddrPort(v.Addr(), v.Port())) + } else { + req.Details.V6AddrPorts = append(req.Details.V6AddrPorts, netAddrToProtoV6AddrPort(v.Addr(), v.Port())) + } } b, err := req.Marshal() @@ -333,75 +358,9 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad } w := &testEncWriter{} - lhh.HandleRequest(fromAddr, vpnIp, b, w) + lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w) } -//TODO: this is a RemoteList test -//func Test_lhRemoteAllowList(t *testing.T) { -// l := NewLogger() -// c := NewConfig(l) -// c.Settings["remoteallowlist"] = map[interface{}]interface{}{ -// "10.20.0.0/12": false, -// } -// allowList, err := c.GetAllowList("remoteallowlist", false) -// assert.Nil(t, err) -// -// lh1 := "10.128.0.2" -// lh1IP := net.ParseIP(lh1) -// -// udpServer, _ := NewListener(l, "0.0.0.0", 0, true) -// -// lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false) -// lh.SetRemoteAllowList(allowList) -// -// // A disallowed ip should not enter the cache but we should end up with an empty entry in the addrMap -// remote1IP := net.ParseIP("10.20.0.3") -// remotes := lh.unlockedGetRemoteList(ip2int(remote1IP)) -// remotes.unlockedPrependV4(ip2int(remote1IP), NewIp4AndPort(remote1IP, 4242)) -// assert.NotNil(t, lh.addrMap[ip2int(remote1IP)]) -// assert.Empty(t, lh.addrMap[ip2int(remote1IP)].CopyAddrs([]*net.IPNet{})) -// -// // Make sure a good ip enters the cache and addrMap -// remote2IP := net.ParseIP("10.128.0.3") -// remote2UDPAddr := NewUDPAddr(remote2IP, uint16(4242)) -// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote2UDPAddr.IP, uint32(remote2UDPAddr.Port)), false, false) -// assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr) -// -// // Another good ip gets into the cache, ordering is inverted -// remote3IP := net.ParseIP("10.128.0.4") -// remote3UDPAddr := NewUDPAddr(remote3IP, uint16(4243)) -// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote3UDPAddr.IP, uint32(remote3UDPAddr.Port)), false, false) -// assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr, remote3UDPAddr) -// -// // If we exceed the length limit we should only have the most recent addresses -// addedAddrs := []*udpAddr{} -// for i := 0; i < 11; i++ { -// remoteUDPAddr := NewUDPAddr(net.IP{10, 128, 0, 4}, uint16(4243+i)) -// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remoteUDPAddr.IP, uint32(remoteUDPAddr.Port)), false, false) -// // The first entry here is a duplicate, don't add it to the assert list -// if i != 0 { -// addedAddrs = append(addedAddrs, remoteUDPAddr) -// } -// } -// -// // We should only have the last 10 of what we tried to add -// assert.True(t, len(addedAddrs) >= 10, "We should have tried to add at least 10 addresses") -// assertUdpAddrInArray( -// t, -// lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), -// addedAddrs[0], -// addedAddrs[1], -// addedAddrs[2], -// addedAddrs[3], -// addedAddrs[4], -// addedAddrs[5], -// addedAddrs[6], -// addedAddrs[7], -// addedAddrs[8], -// addedAddrs[9], -// ) -//} - type testLhReply struct { nebType header.MessageType nebSubType header.MessageSubType @@ -410,8 +369,9 @@ type testLhReply struct { } type testEncWriter struct { - lastReply testLhReply - metaFilter *NebulaMeta_MessageType + lastReply testLhReply + metaFilter *NebulaMeta_MessageType + protocolVersion cert.Version } func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { @@ -426,7 +386,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M tw.lastReply = testLhReply{ nebType: t, nebSubType: st, - vpnIp: hostinfo.vpnIp, + vpnIp: hostinfo.vpnAddrs[0], msg: msg, } } @@ -436,7 +396,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M } } -func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) { +func (tw *testEncWriter) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) { msg := &NebulaMeta{} err := msg.Unmarshal(p) if tw.metaFilter == nil || msg.Type == *tw.metaFilter { @@ -453,17 +413,84 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess } } +func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo { + return nil +} + +func (tw *testEncWriter) GetCertState() *CertState { + return &CertState{defaultVersion: tw.protocolVersion} +} + // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match -func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...netip.AddrPort) { +func assertIp4InArray(t *testing.T, have []*V4AddrPort, want ...netip.AddrPort) { if !assert.Len(t, have, len(want)) { return } for k, w := range want { - //TODO: IPV6-WORK - h := AddrPortFromIp4AndPort(have[k]) + h := protoV4AddrPortToNetAddrPort(have[k]) if !(h == w) { assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h)) } } } + +func Test_findNetworkUnion(t *testing.T) { + var out netip.Addr + var ok bool + + tenDot := netip.MustParsePrefix("10.0.0.0/8") + oneSevenTwo := netip.MustParsePrefix("172.16.0.0/16") + fe80 := netip.MustParsePrefix("fe80::/8") + fc00 := netip.MustParsePrefix("fc00::/7") + + a1 := netip.MustParseAddr("10.0.0.1") + afe81 := netip.MustParseAddr("fe80::1") + + //simple + out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1}) + assert.True(t, ok) + assert.Equal(t, out, a1) + + //mixed lengths + out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1, afe81}) + assert.True(t, ok) + assert.Equal(t, out, a1) + out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo}, []netip.Addr{a1}) + assert.True(t, ok) + assert.Equal(t, out, a1) + + //mixed family + out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1}) + assert.True(t, ok) + assert.Equal(t, out, a1) + out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1, afe81}) + assert.True(t, ok) + assert.Equal(t, out, a1) + + //ordering + out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81, a1}) + assert.True(t, ok) + assert.Equal(t, out, a1) + out, ok = findNetworkUnion([]netip.Prefix{fe80, tenDot, oneSevenTwo}, []netip.Addr{afe81, a1}) + assert.True(t, ok) + assert.Equal(t, out, afe81) + + //some mismatches + out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81}) + assert.True(t, ok) + assert.Equal(t, out, afe81) + out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1, afe81}) + assert.True(t, ok) + assert.Equal(t, out, afe81) + + //falsey cases + out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1}) + assert.False(t, ok) + out, ok = findNetworkUnion([]netip.Prefix{fc00, fe80}, []netip.Addr{a1}) + assert.False(t, ok) + out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fc00}, []netip.Addr{a1, afe81}) + assert.False(t, ok) + out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81}) + assert.False(t, ok) +} diff --git a/main.go b/main.go index e60dbd9..aa7a2bc 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,6 @@ package nebula import ( "context" - "encoding/binary" "fmt" "net" "net/netip" @@ -61,25 +60,12 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err) } - certificate := pki.GetCertState().Certificate - fw, err := NewFirewallFromConfig(l, certificate, c) + fw, err := NewFirewallFromConfig(l, pki.getCertState(), c) if err != nil { return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err) } l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started") - ones, _ := certificate.Details.Ips[0].Mask.Size() - addr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP) - if !ok { - err = util.NewContextualError( - "Invalid ip address in certificate", - m{"vpnIp": certificate.Details.Ips[0].IP}, - nil, - ) - return nil, err - } - tunCidr := netip.PrefixFrom(addr, ones) - ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) if err != nil { return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err) @@ -142,7 +128,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg deviceFactory = overlay.NewDeviceFromConfig } - tun, err = deviceFactory(c, l, tunCidr, routines) + tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err) } @@ -197,9 +183,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } } - hostMap := NewHostMapFromConfig(l, tunCidr, c) + hostMap := NewHostMapFromConfig(l, c) punchy := NewPunchyFromConfig(l, c) - lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy) + lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err) } @@ -242,7 +228,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg Inside: tun, Outside: udpConns[0], pki: pki, - Cipher: c.GetString("cipher", "aes"), Firewall: fw, ServeDns: serveDns, HandshakeManager: handshakeManager, @@ -264,15 +249,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg l: l, } - switch ifConfig.Cipher { - case "aes": - noiseEndianness = binary.BigEndian - case "chachapoly": - noiseEndianness = binary.LittleEndian - default: - return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher) - } - var ifce *Interface if !configTest { ifce, err = NewInterface(ctx, ifConfig) @@ -280,8 +256,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg return nil, fmt.Errorf("failed to initialize interface: %s", err) } - // TODO: Better way to attach these, probably want a new interface in InterfaceConfig - // I don't want to make this initial commit too far-reaching though ifce.writers = udpConns lightHouse.ifce = ifce @@ -326,8 +300,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg go handshakeManager.Run(ctx) } - // TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept - // a context so that they can exit when the context is Done. statsStart, err := startStats(l, c, buildVersion, configTest) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err) @@ -337,7 +309,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg return nil, nil } - //TODO: check if we _should_ be emitting stats go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10)) attachCommands(l, c, ssh, ifce) @@ -346,7 +317,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg var dnsStart func() if lightHouse.amLighthouse && serveDns { l.Debugln("Starting dns server") - dnsStart = dnsMain(l, hostMap, c) + dnsStart = dnsMain(l, pki.getCertState(), hostMap, c) } return &Control{ diff --git a/message_metrics.go b/message_metrics.go index 94bb02f..10e8472 100644 --- a/message_metrics.go +++ b/message_metrics.go @@ -7,8 +7,6 @@ import ( "github.com/slackhq/nebula/header" ) -//TODO: this can probably move into the header package - type MessageMetrics struct { rx [][]metrics.Counter tx [][]metrics.Counter diff --git a/nebula.pb.go b/nebula.pb.go index a753312..946551b 100644 --- a/nebula.pb.go +++ b/nebula.pb.go @@ -96,7 +96,7 @@ func (x NebulaPing_MessageType) String() string { } func (NebulaPing_MessageType) EnumDescriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{4, 0} + return fileDescriptor_2d65afa7693df5ef, []int{5, 0} } type NebulaControl_MessageType int32 @@ -124,7 +124,7 @@ func (x NebulaControl_MessageType) String() string { } func (NebulaControl_MessageType) EnumDescriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{8, 0} + return fileDescriptor_2d65afa7693df5ef, []int{9, 0} } type NebulaMeta struct { @@ -180,11 +180,13 @@ func (m *NebulaMeta) GetDetails() *NebulaMetaDetails { } type NebulaMetaDetails struct { - VpnIp uint32 `protobuf:"varint,1,opt,name=VpnIp,proto3" json:"VpnIp,omitempty"` - Ip4AndPorts []*Ip4AndPort `protobuf:"bytes,2,rep,name=Ip4AndPorts,proto3" json:"Ip4AndPorts,omitempty"` - Ip6AndPorts []*Ip6AndPort `protobuf:"bytes,4,rep,name=Ip6AndPorts,proto3" json:"Ip6AndPorts,omitempty"` - RelayVpnIp []uint32 `protobuf:"varint,5,rep,packed,name=RelayVpnIp,proto3" json:"RelayVpnIp,omitempty"` - Counter uint32 `protobuf:"varint,3,opt,name=counter,proto3" json:"counter,omitempty"` + OldVpnAddr uint32 `protobuf:"varint,1,opt,name=OldVpnAddr,proto3" json:"OldVpnAddr,omitempty"` // Deprecated: Do not use. + VpnAddr *Addr `protobuf:"bytes,6,opt,name=VpnAddr,proto3" json:"VpnAddr,omitempty"` + OldRelayVpnAddrs []uint32 `protobuf:"varint,5,rep,packed,name=OldRelayVpnAddrs,proto3" json:"OldRelayVpnAddrs,omitempty"` // Deprecated: Do not use. + RelayVpnAddrs []*Addr `protobuf:"bytes,7,rep,name=RelayVpnAddrs,proto3" json:"RelayVpnAddrs,omitempty"` + V4AddrPorts []*V4AddrPort `protobuf:"bytes,2,rep,name=V4AddrPorts,proto3" json:"V4AddrPorts,omitempty"` + V6AddrPorts []*V6AddrPort `protobuf:"bytes,4,rep,name=V6AddrPorts,proto3" json:"V6AddrPorts,omitempty"` + Counter uint32 `protobuf:"varint,3,opt,name=counter,proto3" json:"counter,omitempty"` } func (m *NebulaMetaDetails) Reset() { *m = NebulaMetaDetails{} } @@ -220,30 +222,46 @@ func (m *NebulaMetaDetails) XXX_DiscardUnknown() { var xxx_messageInfo_NebulaMetaDetails proto.InternalMessageInfo -func (m *NebulaMetaDetails) GetVpnIp() uint32 { +// Deprecated: Do not use. +func (m *NebulaMetaDetails) GetOldVpnAddr() uint32 { if m != nil { - return m.VpnIp + return m.OldVpnAddr } return 0 } -func (m *NebulaMetaDetails) GetIp4AndPorts() []*Ip4AndPort { +func (m *NebulaMetaDetails) GetVpnAddr() *Addr { if m != nil { - return m.Ip4AndPorts + return m.VpnAddr } return nil } -func (m *NebulaMetaDetails) GetIp6AndPorts() []*Ip6AndPort { +// Deprecated: Do not use. +func (m *NebulaMetaDetails) GetOldRelayVpnAddrs() []uint32 { if m != nil { - return m.Ip6AndPorts + return m.OldRelayVpnAddrs } return nil } -func (m *NebulaMetaDetails) GetRelayVpnIp() []uint32 { +func (m *NebulaMetaDetails) GetRelayVpnAddrs() []*Addr { if m != nil { - return m.RelayVpnIp + return m.RelayVpnAddrs + } + return nil +} + +func (m *NebulaMetaDetails) GetV4AddrPorts() []*V4AddrPort { + if m != nil { + return m.V4AddrPorts + } + return nil +} + +func (m *NebulaMetaDetails) GetV6AddrPorts() []*V6AddrPort { + if m != nil { + return m.V6AddrPorts } return nil } @@ -255,23 +273,23 @@ func (m *NebulaMetaDetails) GetCounter() uint32 { return 0 } -type Ip4AndPort struct { - Ip uint32 `protobuf:"varint,1,opt,name=Ip,proto3" json:"Ip,omitempty"` - Port uint32 `protobuf:"varint,2,opt,name=Port,proto3" json:"Port,omitempty"` +type Addr struct { + Hi uint64 `protobuf:"varint,1,opt,name=Hi,proto3" json:"Hi,omitempty"` + Lo uint64 `protobuf:"varint,2,opt,name=Lo,proto3" json:"Lo,omitempty"` } -func (m *Ip4AndPort) Reset() { *m = Ip4AndPort{} } -func (m *Ip4AndPort) String() string { return proto.CompactTextString(m) } -func (*Ip4AndPort) ProtoMessage() {} -func (*Ip4AndPort) Descriptor() ([]byte, []int) { +func (m *Addr) Reset() { *m = Addr{} } +func (m *Addr) String() string { return proto.CompactTextString(m) } +func (*Addr) ProtoMessage() {} +func (*Addr) Descriptor() ([]byte, []int) { return fileDescriptor_2d65afa7693df5ef, []int{2} } -func (m *Ip4AndPort) XXX_Unmarshal(b []byte) error { +func (m *Addr) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) } -func (m *Ip4AndPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { +func (m *Addr) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { if deterministic { - return xxx_messageInfo_Ip4AndPort.Marshal(b, m, deterministic) + return xxx_messageInfo_Addr.Marshal(b, m, deterministic) } else { b = b[:cap(b)] n, err := m.MarshalToSizedBuffer(b) @@ -281,86 +299,138 @@ func (m *Ip4AndPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return b[:n], nil } } -func (m *Ip4AndPort) XXX_Merge(src proto.Message) { - xxx_messageInfo_Ip4AndPort.Merge(m, src) +func (m *Addr) XXX_Merge(src proto.Message) { + xxx_messageInfo_Addr.Merge(m, src) } -func (m *Ip4AndPort) XXX_Size() int { +func (m *Addr) XXX_Size() int { return m.Size() } -func (m *Ip4AndPort) XXX_DiscardUnknown() { - xxx_messageInfo_Ip4AndPort.DiscardUnknown(m) +func (m *Addr) XXX_DiscardUnknown() { + xxx_messageInfo_Addr.DiscardUnknown(m) } -var xxx_messageInfo_Ip4AndPort proto.InternalMessageInfo +var xxx_messageInfo_Addr proto.InternalMessageInfo -func (m *Ip4AndPort) GetIp() uint32 { - if m != nil { - return m.Ip - } - return 0 -} - -func (m *Ip4AndPort) GetPort() uint32 { - if m != nil { - return m.Port - } - return 0 -} - -type Ip6AndPort struct { - Hi uint64 `protobuf:"varint,1,opt,name=Hi,proto3" json:"Hi,omitempty"` - Lo uint64 `protobuf:"varint,2,opt,name=Lo,proto3" json:"Lo,omitempty"` - Port uint32 `protobuf:"varint,3,opt,name=Port,proto3" json:"Port,omitempty"` -} - -func (m *Ip6AndPort) Reset() { *m = Ip6AndPort{} } -func (m *Ip6AndPort) String() string { return proto.CompactTextString(m) } -func (*Ip6AndPort) ProtoMessage() {} -func (*Ip6AndPort) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{3} -} -func (m *Ip6AndPort) XXX_Unmarshal(b []byte) error { - return m.Unmarshal(b) -} -func (m *Ip6AndPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - if deterministic { - return xxx_messageInfo_Ip6AndPort.Marshal(b, m, deterministic) - } else { - b = b[:cap(b)] - n, err := m.MarshalToSizedBuffer(b) - if err != nil { - return nil, err - } - return b[:n], nil - } -} -func (m *Ip6AndPort) XXX_Merge(src proto.Message) { - xxx_messageInfo_Ip6AndPort.Merge(m, src) -} -func (m *Ip6AndPort) XXX_Size() int { - return m.Size() -} -func (m *Ip6AndPort) XXX_DiscardUnknown() { - xxx_messageInfo_Ip6AndPort.DiscardUnknown(m) -} - -var xxx_messageInfo_Ip6AndPort proto.InternalMessageInfo - -func (m *Ip6AndPort) GetHi() uint64 { +func (m *Addr) GetHi() uint64 { if m != nil { return m.Hi } return 0 } -func (m *Ip6AndPort) GetLo() uint64 { +func (m *Addr) GetLo() uint64 { if m != nil { return m.Lo } return 0 } -func (m *Ip6AndPort) GetPort() uint32 { +type V4AddrPort struct { + Addr uint32 `protobuf:"varint,1,opt,name=Addr,proto3" json:"Addr,omitempty"` + Port uint32 `protobuf:"varint,2,opt,name=Port,proto3" json:"Port,omitempty"` +} + +func (m *V4AddrPort) Reset() { *m = V4AddrPort{} } +func (m *V4AddrPort) String() string { return proto.CompactTextString(m) } +func (*V4AddrPort) ProtoMessage() {} +func (*V4AddrPort) Descriptor() ([]byte, []int) { + return fileDescriptor_2d65afa7693df5ef, []int{3} +} +func (m *V4AddrPort) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *V4AddrPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_V4AddrPort.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *V4AddrPort) XXX_Merge(src proto.Message) { + xxx_messageInfo_V4AddrPort.Merge(m, src) +} +func (m *V4AddrPort) XXX_Size() int { + return m.Size() +} +func (m *V4AddrPort) XXX_DiscardUnknown() { + xxx_messageInfo_V4AddrPort.DiscardUnknown(m) +} + +var xxx_messageInfo_V4AddrPort proto.InternalMessageInfo + +func (m *V4AddrPort) GetAddr() uint32 { + if m != nil { + return m.Addr + } + return 0 +} + +func (m *V4AddrPort) GetPort() uint32 { + if m != nil { + return m.Port + } + return 0 +} + +type V6AddrPort struct { + Hi uint64 `protobuf:"varint,1,opt,name=Hi,proto3" json:"Hi,omitempty"` + Lo uint64 `protobuf:"varint,2,opt,name=Lo,proto3" json:"Lo,omitempty"` + Port uint32 `protobuf:"varint,3,opt,name=Port,proto3" json:"Port,omitempty"` +} + +func (m *V6AddrPort) Reset() { *m = V6AddrPort{} } +func (m *V6AddrPort) String() string { return proto.CompactTextString(m) } +func (*V6AddrPort) ProtoMessage() {} +func (*V6AddrPort) Descriptor() ([]byte, []int) { + return fileDescriptor_2d65afa7693df5ef, []int{4} +} +func (m *V6AddrPort) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *V6AddrPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_V6AddrPort.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *V6AddrPort) XXX_Merge(src proto.Message) { + xxx_messageInfo_V6AddrPort.Merge(m, src) +} +func (m *V6AddrPort) XXX_Size() int { + return m.Size() +} +func (m *V6AddrPort) XXX_DiscardUnknown() { + xxx_messageInfo_V6AddrPort.DiscardUnknown(m) +} + +var xxx_messageInfo_V6AddrPort proto.InternalMessageInfo + +func (m *V6AddrPort) GetHi() uint64 { + if m != nil { + return m.Hi + } + return 0 +} + +func (m *V6AddrPort) GetLo() uint64 { + if m != nil { + return m.Lo + } + return 0 +} + +func (m *V6AddrPort) GetPort() uint32 { if m != nil { return m.Port } @@ -376,7 +446,7 @@ func (m *NebulaPing) Reset() { *m = NebulaPing{} } func (m *NebulaPing) String() string { return proto.CompactTextString(m) } func (*NebulaPing) ProtoMessage() {} func (*NebulaPing) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{4} + return fileDescriptor_2d65afa7693df5ef, []int{5} } func (m *NebulaPing) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -428,7 +498,7 @@ func (m *NebulaHandshake) Reset() { *m = NebulaHandshake{} } func (m *NebulaHandshake) String() string { return proto.CompactTextString(m) } func (*NebulaHandshake) ProtoMessage() {} func (*NebulaHandshake) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{5} + return fileDescriptor_2d65afa7693df5ef, []int{6} } func (m *NebulaHandshake) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -482,7 +552,7 @@ func (m *MultiPortDetails) Reset() { *m = MultiPortDetails{} } func (m *MultiPortDetails) String() string { return proto.CompactTextString(m) } func (*MultiPortDetails) ProtoMessage() {} func (*MultiPortDetails) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{6} + return fileDescriptor_2d65afa7693df5ef, []int{7} } func (m *MultiPortDetails) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -545,6 +615,7 @@ type NebulaHandshakeDetails struct { ResponderIndex uint32 `protobuf:"varint,3,opt,name=ResponderIndex,proto3" json:"ResponderIndex,omitempty"` Cookie uint64 `protobuf:"varint,4,opt,name=Cookie,proto3" json:"Cookie,omitempty"` Time uint64 `protobuf:"varint,5,opt,name=Time,proto3" json:"Time,omitempty"` + CertVersion uint32 `protobuf:"varint,8,opt,name=CertVersion,proto3" json:"CertVersion,omitempty"` InitiatorMultiPort *MultiPortDetails `protobuf:"bytes,6,opt,name=InitiatorMultiPort,proto3" json:"InitiatorMultiPort,omitempty"` ResponderMultiPort *MultiPortDetails `protobuf:"bytes,7,opt,name=ResponderMultiPort,proto3" json:"ResponderMultiPort,omitempty"` } @@ -553,7 +624,7 @@ func (m *NebulaHandshakeDetails) Reset() { *m = NebulaHandshakeDetails{} func (m *NebulaHandshakeDetails) String() string { return proto.CompactTextString(m) } func (*NebulaHandshakeDetails) ProtoMessage() {} func (*NebulaHandshakeDetails) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{7} + return fileDescriptor_2d65afa7693df5ef, []int{8} } func (m *NebulaHandshakeDetails) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -617,6 +688,13 @@ func (m *NebulaHandshakeDetails) GetTime() uint64 { return 0 } +func (m *NebulaHandshakeDetails) GetCertVersion() uint32 { + if m != nil { + return m.CertVersion + } + return 0 +} + func (m *NebulaHandshakeDetails) GetInitiatorMultiPort() *MultiPortDetails { if m != nil { return m.InitiatorMultiPort @@ -635,15 +713,17 @@ type NebulaControl struct { Type NebulaControl_MessageType `protobuf:"varint,1,opt,name=Type,proto3,enum=nebula.NebulaControl_MessageType" json:"Type,omitempty"` InitiatorRelayIndex uint32 `protobuf:"varint,2,opt,name=InitiatorRelayIndex,proto3" json:"InitiatorRelayIndex,omitempty"` ResponderRelayIndex uint32 `protobuf:"varint,3,opt,name=ResponderRelayIndex,proto3" json:"ResponderRelayIndex,omitempty"` - RelayToIp uint32 `protobuf:"varint,4,opt,name=RelayToIp,proto3" json:"RelayToIp,omitempty"` - RelayFromIp uint32 `protobuf:"varint,5,opt,name=RelayFromIp,proto3" json:"RelayFromIp,omitempty"` + OldRelayToAddr uint32 `protobuf:"varint,4,opt,name=OldRelayToAddr,proto3" json:"OldRelayToAddr,omitempty"` // Deprecated: Do not use. + OldRelayFromAddr uint32 `protobuf:"varint,5,opt,name=OldRelayFromAddr,proto3" json:"OldRelayFromAddr,omitempty"` // Deprecated: Do not use. + RelayToAddr *Addr `protobuf:"bytes,6,opt,name=RelayToAddr,proto3" json:"RelayToAddr,omitempty"` + RelayFromAddr *Addr `protobuf:"bytes,7,opt,name=RelayFromAddr,proto3" json:"RelayFromAddr,omitempty"` } func (m *NebulaControl) Reset() { *m = NebulaControl{} } func (m *NebulaControl) String() string { return proto.CompactTextString(m) } func (*NebulaControl) ProtoMessage() {} func (*NebulaControl) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{8} + return fileDescriptor_2d65afa7693df5ef, []int{9} } func (m *NebulaControl) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -693,28 +773,45 @@ func (m *NebulaControl) GetResponderRelayIndex() uint32 { return 0 } -func (m *NebulaControl) GetRelayToIp() uint32 { +// Deprecated: Do not use. +func (m *NebulaControl) GetOldRelayToAddr() uint32 { if m != nil { - return m.RelayToIp + return m.OldRelayToAddr } return 0 } -func (m *NebulaControl) GetRelayFromIp() uint32 { +// Deprecated: Do not use. +func (m *NebulaControl) GetOldRelayFromAddr() uint32 { if m != nil { - return m.RelayFromIp + return m.OldRelayFromAddr } return 0 } +func (m *NebulaControl) GetRelayToAddr() *Addr { + if m != nil { + return m.RelayToAddr + } + return nil +} + +func (m *NebulaControl) GetRelayFromAddr() *Addr { + if m != nil { + return m.RelayFromAddr + } + return nil +} + func init() { proto.RegisterEnum("nebula.NebulaMeta_MessageType", NebulaMeta_MessageType_name, NebulaMeta_MessageType_value) proto.RegisterEnum("nebula.NebulaPing_MessageType", NebulaPing_MessageType_name, NebulaPing_MessageType_value) proto.RegisterEnum("nebula.NebulaControl_MessageType", NebulaControl_MessageType_name, NebulaControl_MessageType_value) proto.RegisterType((*NebulaMeta)(nil), "nebula.NebulaMeta") proto.RegisterType((*NebulaMetaDetails)(nil), "nebula.NebulaMetaDetails") - proto.RegisterType((*Ip4AndPort)(nil), "nebula.Ip4AndPort") - proto.RegisterType((*Ip6AndPort)(nil), "nebula.Ip6AndPort") + proto.RegisterType((*Addr)(nil), "nebula.Addr") + proto.RegisterType((*V4AddrPort)(nil), "nebula.V4AddrPort") + proto.RegisterType((*V6AddrPort)(nil), "nebula.V6AddrPort") proto.RegisterType((*NebulaPing)(nil), "nebula.NebulaPing") proto.RegisterType((*NebulaHandshake)(nil), "nebula.NebulaHandshake") proto.RegisterType((*MultiPortDetails)(nil), "nebula.MultiPortDetails") @@ -725,56 +822,61 @@ func init() { func init() { proto.RegisterFile("nebula.proto", fileDescriptor_2d65afa7693df5ef) } var fileDescriptor_2d65afa7693df5ef = []byte{ - // 784 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x55, 0xcd, 0x8e, 0xe3, 0x44, - 0x10, 0x8e, 0x1d, 0xe7, 0xaf, 0x32, 0xc9, 0x9a, 0x1a, 0x08, 0xc9, 0x0a, 0xac, 0xe0, 0x03, 0xca, - 0x29, 0xbb, 0xca, 0x2c, 0x23, 0x8e, 0xec, 0x06, 0xa1, 0x44, 0xda, 0x8c, 0x42, 0x13, 0x40, 0xe2, - 0x82, 0x7a, 0x9c, 0x66, 0x62, 0xc5, 0x71, 0x7b, 0xed, 0x36, 0x9a, 0xbc, 0x05, 0xe2, 0x59, 0x38, - 0xf2, 0x00, 0xdc, 0xd8, 0x23, 0x47, 0x34, 0x73, 0xe4, 0xc8, 0x0b, 0xa0, 0x6e, 0xff, 0xe6, 0x07, - 0xb8, 0x75, 0x55, 0x7d, 0x5f, 0xf5, 0xd7, 0x55, 0x5f, 0x1c, 0xb8, 0xf0, 0xd9, 0x6d, 0xec, 0xd1, - 0x71, 0x10, 0x72, 0xc1, 0xb1, 0x9e, 0x44, 0xf6, 0x5f, 0x3a, 0xc0, 0x8d, 0x3a, 0x2e, 0x98, 0xa0, - 0x38, 0x01, 0x63, 0xb5, 0x0f, 0x58, 0x5f, 0x1b, 0x6a, 0xa3, 0xee, 0xc4, 0x1a, 0xa7, 0x9c, 0x02, - 0x31, 0x5e, 0xb0, 0x28, 0xa2, 0x77, 0x4c, 0xa2, 0x88, 0xc2, 0xe2, 0x15, 0x34, 0x3e, 0x67, 0x82, - 0xba, 0x5e, 0xd4, 0xd7, 0x87, 0xda, 0xa8, 0x3d, 0x19, 0x9c, 0xd2, 0x52, 0x00, 0xc9, 0x90, 0xf6, - 0xdf, 0x1a, 0xb4, 0x4b, 0xad, 0xb0, 0x09, 0xc6, 0x0d, 0xf7, 0x99, 0x59, 0xc1, 0x0e, 0xb4, 0x66, - 0x3c, 0x12, 0x5f, 0xc6, 0x2c, 0xdc, 0x9b, 0x1a, 0x22, 0x74, 0xf3, 0x90, 0xb0, 0xc0, 0xdb, 0x9b, - 0x3a, 0x3e, 0x85, 0x9e, 0xcc, 0x7d, 0x1d, 0xac, 0xa9, 0x60, 0x37, 0x5c, 0xb8, 0x3f, 0xb8, 0x0e, - 0x15, 0x2e, 0xf7, 0xcd, 0x2a, 0x0e, 0xe0, 0x3d, 0x59, 0x5b, 0xf0, 0x1f, 0xd9, 0xfa, 0xa0, 0x64, - 0x64, 0xa5, 0x65, 0xec, 0x3b, 0x9b, 0x83, 0x52, 0x0d, 0xbb, 0x00, 0xb2, 0xf4, 0xed, 0x86, 0xd3, - 0x9d, 0x6b, 0xd6, 0xf1, 0x12, 0x9e, 0x14, 0x71, 0x72, 0x6d, 0x43, 0x2a, 0x5b, 0x52, 0xb1, 0x99, - 0x6e, 0x98, 0xb3, 0x35, 0x9b, 0x52, 0x59, 0x1e, 0x26, 0x90, 0x16, 0x7e, 0x08, 0x83, 0xf3, 0xca, - 0x5e, 0x3a, 0x5b, 0x13, 0xec, 0xdf, 0x35, 0x78, 0xe7, 0x64, 0x28, 0xf8, 0x2e, 0xd4, 0xbe, 0x09, - 0xfc, 0x79, 0xa0, 0xa6, 0xde, 0x21, 0x49, 0x80, 0x2f, 0xa0, 0x3d, 0x0f, 0x5e, 0xbc, 0xf4, 0xd7, - 0x4b, 0x1e, 0x0a, 0x39, 0xda, 0xea, 0xa8, 0x3d, 0xc1, 0x6c, 0xb4, 0x45, 0x89, 0x94, 0x61, 0x09, - 0xeb, 0x3a, 0x67, 0x19, 0xc7, 0xac, 0xeb, 0x12, 0x2b, 0x87, 0xa1, 0x05, 0x40, 0x98, 0x47, 0xf7, - 0x89, 0x8c, 0xda, 0xb0, 0x3a, 0xea, 0x90, 0x52, 0x06, 0xfb, 0xd0, 0x70, 0x78, 0xec, 0x0b, 0x16, - 0xf6, 0xab, 0x4a, 0x63, 0x16, 0xda, 0xcf, 0x01, 0x8a, 0xeb, 0xb1, 0x0b, 0x7a, 0xfe, 0x0c, 0x7d, - 0x1e, 0x20, 0x82, 0x21, 0xf3, 0xca, 0x17, 0x1d, 0xa2, 0xce, 0xf6, 0x67, 0x92, 0x71, 0x5d, 0x62, - 0xcc, 0x5c, 0xc5, 0x30, 0x88, 0x3e, 0x73, 0x65, 0xfc, 0x9a, 0x2b, 0xbc, 0x41, 0xf4, 0xd7, 0x3c, - 0xef, 0x50, 0x2d, 0x75, 0xb8, 0xcf, 0x2c, 0xbb, 0x74, 0xfd, 0xbb, 0xff, 0xb6, 0xac, 0x44, 0x9c, - 0xb1, 0x2c, 0x82, 0xb1, 0x72, 0x77, 0x2c, 0xbd, 0x47, 0x9d, 0x6d, 0xfb, 0xc4, 0x90, 0x92, 0x6c, - 0x56, 0xb0, 0x05, 0xb5, 0x64, 0xbd, 0x9a, 0xfd, 0x3d, 0x3c, 0x49, 0xfa, 0xce, 0xa8, 0xbf, 0x8e, - 0x36, 0x74, 0xcb, 0xf0, 0xd3, 0xc2, 0xfd, 0x9a, 0x72, 0xff, 0x91, 0x82, 0x1c, 0x79, 0xfc, 0x13, - 0x90, 0x22, 0x66, 0x3b, 0xea, 0x28, 0x11, 0x17, 0x44, 0x9d, 0xed, 0x9f, 0x35, 0x30, 0x17, 0xb1, - 0x27, 0x5c, 0xf9, 0xd0, 0x0c, 0x38, 0x84, 0x36, 0xb9, 0xff, 0x2a, 0x0e, 0x02, 0x1e, 0x0a, 0xb6, - 0x56, 0xd7, 0x34, 0x49, 0x39, 0x25, 0x11, 0xab, 0x12, 0x42, 0x4f, 0x10, 0xa5, 0x14, 0x3e, 0x85, - 0xe6, 0x2b, 0x1a, 0xb1, 0xd2, 0x2c, 0xf3, 0x58, 0x6e, 0x7f, 0xc5, 0x05, 0xf5, 0x32, 0xcb, 0xc8, - 0x6a, 0x29, 0x63, 0xff, 0xaa, 0x43, 0xef, 0xfc, 0x63, 0xe4, 0x1b, 0xa6, 0x2c, 0x14, 0x4a, 0xd3, - 0x05, 0x51, 0x67, 0xfc, 0x18, 0xba, 0x73, 0xdf, 0x15, 0x2e, 0x15, 0x3c, 0x9c, 0xfb, 0x6b, 0x76, - 0x9f, 0xae, 0xff, 0x28, 0x2b, 0x71, 0x84, 0x45, 0x01, 0xf7, 0xd7, 0x2c, 0xc5, 0x25, 0xc2, 0x8e, - 0xb2, 0xd8, 0x83, 0xfa, 0x94, 0xf3, 0xad, 0xcb, 0x94, 0x34, 0x83, 0xa4, 0x51, 0xbe, 0xc4, 0x5a, - 0xb1, 0x44, 0x9c, 0x01, 0xe6, 0xb7, 0xe4, 0x73, 0xec, 0xd7, 0xd5, 0x62, 0xfa, 0xd9, 0x62, 0x8e, - 0x07, 0x4c, 0xce, 0x70, 0x64, 0xa7, 0x5c, 0x47, 0xd1, 0xa9, 0xf1, 0x7f, 0x9d, 0x4e, 0x39, 0xf6, - 0x2f, 0x3a, 0x74, 0x92, 0xf1, 0x4d, 0xb9, 0x2f, 0x42, 0xee, 0xe1, 0x27, 0x07, 0x96, 0xfd, 0xe8, - 0xd0, 0x30, 0x29, 0xe8, 0x8c, 0x6b, 0x9f, 0xc3, 0x65, 0x2e, 0x54, 0xfd, 0x38, 0xcb, 0xd3, 0x3d, - 0x57, 0x92, 0x8c, 0x5c, 0x50, 0x89, 0x91, 0xcc, 0xf9, 0x5c, 0x09, 0x3f, 0x80, 0x96, 0x8a, 0x56, - 0x7c, 0x1e, 0xa4, 0x56, 0x28, 0x12, 0xca, 0x89, 0x32, 0xf8, 0x22, 0xe4, 0x3b, 0xf5, 0xa1, 0x90, - 0xf5, 0x72, 0xca, 0x9e, 0xfd, 0xdb, 0x67, 0xbd, 0x07, 0x38, 0x0d, 0x19, 0x15, 0x4c, 0xa1, 0x09, - 0x7b, 0x13, 0xb3, 0x48, 0x98, 0x1a, 0xbe, 0x0f, 0x97, 0x07, 0x79, 0x29, 0x29, 0x62, 0xa6, 0xfe, - 0xea, 0xea, 0xb7, 0x07, 0x4b, 0x7b, 0xfb, 0x60, 0x69, 0x7f, 0x3e, 0x58, 0xda, 0x4f, 0x8f, 0x56, - 0xe5, 0xed, 0xa3, 0x55, 0xf9, 0xe3, 0xd1, 0xaa, 0x7c, 0x37, 0xb8, 0x73, 0xc5, 0x26, 0xbe, 0x1d, - 0x3b, 0x7c, 0xf7, 0x2c, 0xf2, 0xa8, 0xb3, 0xdd, 0xbc, 0x79, 0x96, 0x8c, 0xf0, 0xb6, 0xae, 0xfe, - 0xdd, 0xae, 0xfe, 0x09, 0x00, 0x00, 0xff, 0xff, 0xb2, 0xb5, 0xba, 0xcc, 0xed, 0x06, 0x00, 0x00, + // 864 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x56, 0x4f, 0x6f, 0xe3, 0x44, + 0x14, 0x8f, 0x1d, 0xe7, 0x4f, 0x5f, 0x9a, 0xac, 0x79, 0x15, 0x25, 0x5d, 0x89, 0x28, 0xf8, 0x50, + 0xad, 0x38, 0x64, 0x51, 0x5b, 0x56, 0x1c, 0xd9, 0x06, 0xa1, 0xac, 0xb4, 0xed, 0x96, 0x21, 0x14, + 0x89, 0x0b, 0x9a, 0xc6, 0x43, 0x63, 0xc5, 0xf1, 0x78, 0xed, 0x31, 0x6a, 0xbe, 0x05, 0xe2, 0xb3, + 0xf0, 0x21, 0xe0, 0xb6, 0x47, 0x4e, 0x08, 0xb5, 0x47, 0x8e, 0x7c, 0x01, 0x34, 0xe3, 0x7f, 0xe3, + 0xc4, 0x6c, 0x6f, 0xf3, 0xde, 0xef, 0xf7, 0x7b, 0xfe, 0xcd, 0x9b, 0x79, 0x93, 0xc0, 0x7e, 0xc0, + 0x6e, 0x12, 0x9f, 0x4e, 0xc2, 0x88, 0x0b, 0x8e, 0xed, 0x34, 0x72, 0xfe, 0x31, 0x01, 0x2e, 0xd5, + 0xf2, 0x82, 0x09, 0x8a, 0x27, 0x60, 0xcd, 0x37, 0x21, 0x1b, 0x1a, 0x63, 0xe3, 0xd9, 0xe0, 0x64, + 0x34, 0xc9, 0x34, 0x25, 0x63, 0x72, 0xc1, 0xe2, 0x98, 0xde, 0x32, 0xc9, 0x22, 0x8a, 0x8b, 0xa7, + 0xd0, 0xf9, 0x8a, 0x09, 0xea, 0xf9, 0xf1, 0xd0, 0x1c, 0x1b, 0xcf, 0x7a, 0x27, 0x47, 0xbb, 0xb2, + 0x8c, 0x40, 0x72, 0xa6, 0xf3, 0xaf, 0x01, 0x3d, 0xad, 0x14, 0x76, 0xc1, 0xba, 0xe4, 0x01, 0xb3, + 0x1b, 0xd8, 0x87, 0xbd, 0x19, 0x8f, 0xc5, 0x37, 0x09, 0x8b, 0x36, 0xb6, 0x81, 0x08, 0x83, 0x22, + 0x24, 0x2c, 0xf4, 0x37, 0xb6, 0x89, 0x4f, 0xe1, 0x50, 0xe6, 0xbe, 0x0b, 0x5d, 0x2a, 0xd8, 0x25, + 0x17, 0xde, 0x4f, 0xde, 0x82, 0x0a, 0x8f, 0x07, 0x76, 0x13, 0x8f, 0xe0, 0x43, 0x89, 0x5d, 0xf0, + 0x9f, 0x99, 0x5b, 0x81, 0xac, 0x1c, 0xba, 0x4a, 0x82, 0xc5, 0xb2, 0x02, 0xb5, 0x70, 0x00, 0x20, + 0xa1, 0xef, 0x97, 0x9c, 0xae, 0x3d, 0xbb, 0x8d, 0x07, 0xf0, 0xa4, 0x8c, 0xd3, 0xcf, 0x76, 0xa4, + 0xb3, 0x2b, 0x2a, 0x96, 0xd3, 0x25, 0x5b, 0xac, 0xec, 0xae, 0x74, 0x56, 0x84, 0x29, 0x65, 0x0f, + 0x3f, 0x86, 0xa3, 0x7a, 0x67, 0x2f, 0x17, 0x2b, 0x1b, 0x9c, 0x3f, 0x4c, 0xf8, 0x60, 0xa7, 0x29, + 0xe8, 0x00, 0xbc, 0xf1, 0xdd, 0xeb, 0x30, 0x78, 0xe9, 0xba, 0x91, 0x6a, 0x7d, 0xff, 0xdc, 0x1c, + 0x1a, 0x44, 0xcb, 0xe2, 0x31, 0x74, 0x72, 0x42, 0x5b, 0x35, 0x79, 0x3f, 0x6f, 0xb2, 0xcc, 0x91, + 0x1c, 0xc4, 0x09, 0xd8, 0x6f, 0x7c, 0x97, 0x30, 0x9f, 0x6e, 0xb2, 0x54, 0x3c, 0x6c, 0x8d, 0x9b, + 0x59, 0xc5, 0x1d, 0x0c, 0x4f, 0xa0, 0x5f, 0x25, 0x77, 0xc6, 0xcd, 0x9d, 0xea, 0x55, 0x0a, 0x9e, + 0x41, 0xef, 0xfa, 0x4c, 0x2e, 0xaf, 0x78, 0x24, 0xe4, 0xa1, 0x4b, 0x05, 0xe6, 0x8a, 0x12, 0x22, + 0x3a, 0x4d, 0xa9, 0x5e, 0x94, 0x2a, 0x6b, 0x4b, 0xf5, 0x42, 0x53, 0x95, 0x34, 0x1c, 0x42, 0x67, + 0xc1, 0x93, 0x40, 0xb0, 0x68, 0xd8, 0x94, 0x8d, 0x21, 0x79, 0xe8, 0x1c, 0x83, 0xa5, 0x76, 0x3c, + 0x00, 0x73, 0xe6, 0xa9, 0xae, 0x59, 0xc4, 0x9c, 0x79, 0x32, 0x7e, 0xcd, 0xd5, 0x4d, 0xb4, 0x88, + 0xf9, 0x9a, 0x3b, 0x67, 0x00, 0xa5, 0x0d, 0xc4, 0x54, 0x95, 0x76, 0x99, 0xa4, 0x15, 0x10, 0x2c, + 0x89, 0x29, 0x4d, 0x9f, 0xa8, 0xb5, 0xf3, 0x25, 0x40, 0x69, 0xe3, 0xb1, 0x6f, 0x14, 0x15, 0x9a, + 0x5a, 0x85, 0xbb, 0x7c, 0xb0, 0xae, 0xbc, 0xe0, 0xf6, 0xfd, 0x83, 0x25, 0x19, 0x35, 0x83, 0x85, + 0x60, 0xcd, 0xbd, 0x35, 0xcb, 0xbe, 0xa3, 0xd6, 0x8e, 0xb3, 0x33, 0x36, 0x52, 0x6c, 0x37, 0x70, + 0x0f, 0x5a, 0xe9, 0x25, 0x34, 0x9c, 0x1f, 0xe1, 0x49, 0x5a, 0x77, 0x46, 0x03, 0x37, 0x5e, 0xd2, + 0x15, 0xc3, 0x2f, 0xca, 0x19, 0x35, 0xd4, 0xf5, 0xd9, 0x72, 0x50, 0x30, 0xb7, 0x07, 0x55, 0x9a, + 0x98, 0xad, 0xe9, 0x42, 0x99, 0xd8, 0x27, 0x6a, 0xed, 0xfc, 0x6a, 0x80, 0x7d, 0x91, 0xf8, 0xc2, + 0x93, 0x1b, 0xcd, 0x89, 0x63, 0xe8, 0x91, 0xbb, 0x6f, 0x93, 0x30, 0xe4, 0x91, 0x60, 0xae, 0xfa, + 0x4c, 0x97, 0xe8, 0x29, 0xc9, 0x98, 0x6b, 0x0c, 0x33, 0x65, 0x68, 0x29, 0x7c, 0x0a, 0xdd, 0x73, + 0x1a, 0x33, 0xad, 0x97, 0x45, 0x8c, 0x23, 0x80, 0x39, 0x17, 0xd4, 0xcf, 0xaf, 0x8f, 0x44, 0xb5, + 0x8c, 0xf3, 0x97, 0x09, 0x87, 0xf5, 0x9b, 0x91, 0x7b, 0x98, 0xb2, 0x48, 0x28, 0x4f, 0xfb, 0x44, + 0xad, 0xf1, 0x18, 0x06, 0xaf, 0x02, 0x4f, 0x78, 0x54, 0xf0, 0xe8, 0x55, 0xe0, 0xb2, 0xbb, 0xec, + 0xf8, 0xb7, 0xb2, 0x92, 0x47, 0x58, 0x1c, 0xf2, 0xc0, 0x65, 0x19, 0x2f, 0x35, 0xb6, 0x95, 0xc5, + 0x43, 0x68, 0x4f, 0x39, 0x5f, 0x79, 0x4c, 0x59, 0xb3, 0x48, 0x16, 0x15, 0x87, 0xd8, 0x2a, 0x0f, + 0x51, 0x36, 0x42, 0x7a, 0xb8, 0x66, 0x51, 0xec, 0xf1, 0x60, 0xd8, 0x55, 0x05, 0xf5, 0x14, 0xce, + 0x00, 0x0b, 0x1f, 0x45, 0xa7, 0xb3, 0xc9, 0x1f, 0xe6, 0x47, 0xb7, 0x7d, 0x04, 0xa4, 0x46, 0x23, + 0x2b, 0x15, 0x4e, 0xcb, 0x4a, 0x9d, 0xc7, 0x2a, 0xed, 0x6a, 0x9c, 0xdf, 0x9a, 0xd0, 0x4f, 0x1b, + 0x3c, 0xe5, 0x81, 0x88, 0xb8, 0x8f, 0x9f, 0x57, 0x2e, 0xf5, 0x27, 0xd5, 0x2b, 0x95, 0x91, 0x6a, + 0xee, 0xf5, 0x67, 0x70, 0x50, 0x18, 0x55, 0x2f, 0x8b, 0xde, 0xff, 0x3a, 0x48, 0x2a, 0x0a, 0x43, + 0x9a, 0x22, 0x3d, 0x89, 0x3a, 0x08, 0x3f, 0x85, 0x41, 0xfe, 0xd6, 0xcd, 0xb9, 0x9a, 0x78, 0xab, + 0x78, 0x57, 0xb7, 0x10, 0xfd, 0xcd, 0xfc, 0x3a, 0xe2, 0x6b, 0xc5, 0x6e, 0x15, 0xec, 0x1d, 0x0c, + 0x27, 0xd0, 0xd3, 0x0b, 0xd7, 0xbd, 0xc7, 0x3a, 0xa1, 0x78, 0x63, 0x8b, 0xe2, 0x9d, 0x1a, 0x45, + 0x95, 0xe2, 0xcc, 0xfe, 0xef, 0xe7, 0xf1, 0x10, 0x70, 0x1a, 0x31, 0x2a, 0x98, 0xe2, 0x13, 0xf6, + 0x36, 0x61, 0xb1, 0xb0, 0x0d, 0xfc, 0x08, 0x0e, 0x2a, 0x79, 0xd9, 0x92, 0x98, 0xd9, 0xe6, 0xf9, + 0xe9, 0xef, 0xf7, 0x23, 0xe3, 0xdd, 0xfd, 0xc8, 0xf8, 0xfb, 0x7e, 0x64, 0xfc, 0xf2, 0x30, 0x6a, + 0xbc, 0x7b, 0x18, 0x35, 0xfe, 0x7c, 0x18, 0x35, 0x7e, 0x38, 0xba, 0xf5, 0xc4, 0x32, 0xb9, 0x99, + 0x2c, 0xf8, 0xfa, 0x79, 0xec, 0xd3, 0xc5, 0x6a, 0xf9, 0xf6, 0x79, 0x6a, 0xe9, 0xa6, 0xad, 0xfe, + 0x25, 0x9c, 0xfe, 0x17, 0x00, 0x00, 0xff, 0xff, 0xd6, 0x71, 0x5a, 0xf8, 0x35, 0x08, 0x00, 0x00, } func (m *NebulaMeta) Marshal() (dAtA []byte, err error) { @@ -837,28 +939,54 @@ func (m *NebulaMetaDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) { _ = i var l int _ = l - if len(m.RelayVpnIp) > 0 { - dAtA3 := make([]byte, len(m.RelayVpnIp)*10) - var j2 int - for _, num := range m.RelayVpnIp { - for num >= 1<<7 { - dAtA3[j2] = uint8(uint64(num)&0x7f | 0x80) - num >>= 7 - j2++ + if len(m.RelayVpnAddrs) > 0 { + for iNdEx := len(m.RelayVpnAddrs) - 1; iNdEx >= 0; iNdEx-- { + { + size, err := m.RelayVpnAddrs[iNdEx].MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintNebula(dAtA, i, uint64(size)) } - dAtA3[j2] = uint8(num) - j2++ + i-- + dAtA[i] = 0x3a } - i -= j2 - copy(dAtA[i:], dAtA3[:j2]) - i = encodeVarintNebula(dAtA, i, uint64(j2)) + } + if m.VpnAddr != nil { + { + size, err := m.VpnAddr.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintNebula(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x32 + } + if len(m.OldRelayVpnAddrs) > 0 { + dAtA4 := make([]byte, len(m.OldRelayVpnAddrs)*10) + var j3 int + for _, num := range m.OldRelayVpnAddrs { + for num >= 1<<7 { + dAtA4[j3] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + j3++ + } + dAtA4[j3] = uint8(num) + j3++ + } + i -= j3 + copy(dAtA[i:], dAtA4[:j3]) + i = encodeVarintNebula(dAtA, i, uint64(j3)) i-- dAtA[i] = 0x2a } - if len(m.Ip6AndPorts) > 0 { - for iNdEx := len(m.Ip6AndPorts) - 1; iNdEx >= 0; iNdEx-- { + if len(m.V6AddrPorts) > 0 { + for iNdEx := len(m.V6AddrPorts) - 1; iNdEx >= 0; iNdEx-- { { - size, err := m.Ip6AndPorts[iNdEx].MarshalToSizedBuffer(dAtA[:i]) + size, err := m.V6AddrPorts[iNdEx].MarshalToSizedBuffer(dAtA[:i]) if err != nil { return 0, err } @@ -874,10 +1002,10 @@ func (m *NebulaMetaDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) { i-- dAtA[i] = 0x18 } - if len(m.Ip4AndPorts) > 0 { - for iNdEx := len(m.Ip4AndPorts) - 1; iNdEx >= 0; iNdEx-- { + if len(m.V4AddrPorts) > 0 { + for iNdEx := len(m.V4AddrPorts) - 1; iNdEx >= 0; iNdEx-- { { - size, err := m.Ip4AndPorts[iNdEx].MarshalToSizedBuffer(dAtA[:i]) + size, err := m.V4AddrPorts[iNdEx].MarshalToSizedBuffer(dAtA[:i]) if err != nil { return 0, err } @@ -888,15 +1016,15 @@ func (m *NebulaMetaDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) { dAtA[i] = 0x12 } } - if m.VpnIp != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.VpnIp)) + if m.OldVpnAddr != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.OldVpnAddr)) i-- dAtA[i] = 0x8 } return len(dAtA) - i, nil } -func (m *Ip4AndPort) Marshal() (dAtA []byte, err error) { +func (m *Addr) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) n, err := m.MarshalToSizedBuffer(dAtA[:size]) @@ -906,12 +1034,45 @@ func (m *Ip4AndPort) Marshal() (dAtA []byte, err error) { return dAtA[:n], nil } -func (m *Ip4AndPort) MarshalTo(dAtA []byte) (int, error) { +func (m *Addr) MarshalTo(dAtA []byte) (int, error) { size := m.Size() return m.MarshalToSizedBuffer(dAtA[:size]) } -func (m *Ip4AndPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { +func (m *Addr) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.Lo != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.Lo)) + i-- + dAtA[i] = 0x10 + } + if m.Hi != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.Hi)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + +func (m *V4AddrPort) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *V4AddrPort) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *V4AddrPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { i := len(dAtA) _ = i var l int @@ -921,15 +1082,15 @@ func (m *Ip4AndPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { i-- dAtA[i] = 0x10 } - if m.Ip != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.Ip)) + if m.Addr != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.Addr)) i-- dAtA[i] = 0x8 } return len(dAtA) - i, nil } -func (m *Ip6AndPort) Marshal() (dAtA []byte, err error) { +func (m *V6AddrPort) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) n, err := m.MarshalToSizedBuffer(dAtA[:size]) @@ -939,12 +1100,12 @@ func (m *Ip6AndPort) Marshal() (dAtA []byte, err error) { return dAtA[:n], nil } -func (m *Ip6AndPort) MarshalTo(dAtA []byte) (int, error) { +func (m *V6AddrPort) MarshalTo(dAtA []byte) (int, error) { size := m.Size() return m.MarshalToSizedBuffer(dAtA[:size]) } -func (m *Ip6AndPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { +func (m *V6AddrPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { i := len(dAtA) _ = i var l int @@ -1115,6 +1276,11 @@ func (m *NebulaHandshakeDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) _ = i var l int _ = l + if m.CertVersion != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.CertVersion)) + i-- + dAtA[i] = 0x40 + } if m.ResponderMultiPort != nil { { size, err := m.ResponderMultiPort.MarshalToSizedBuffer(dAtA[:i]) @@ -1189,13 +1355,37 @@ func (m *NebulaControl) MarshalToSizedBuffer(dAtA []byte) (int, error) { _ = i var l int _ = l - if m.RelayFromIp != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.RelayFromIp)) + if m.RelayFromAddr != nil { + { + size, err := m.RelayFromAddr.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintNebula(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x3a + } + if m.RelayToAddr != nil { + { + size, err := m.RelayToAddr.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintNebula(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x32 + } + if m.OldRelayFromAddr != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.OldRelayFromAddr)) i-- dAtA[i] = 0x28 } - if m.RelayToIp != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.RelayToIp)) + if m.OldRelayToAddr != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.OldRelayToAddr)) i-- dAtA[i] = 0x20 } @@ -1250,11 +1440,11 @@ func (m *NebulaMetaDetails) Size() (n int) { } var l int _ = l - if m.VpnIp != 0 { - n += 1 + sovNebula(uint64(m.VpnIp)) + if m.OldVpnAddr != 0 { + n += 1 + sovNebula(uint64(m.OldVpnAddr)) } - if len(m.Ip4AndPorts) > 0 { - for _, e := range m.Ip4AndPorts { + if len(m.V4AddrPorts) > 0 { + for _, e := range m.V4AddrPorts { l = e.Size() n += 1 + l + sovNebula(uint64(l)) } @@ -1262,30 +1452,55 @@ func (m *NebulaMetaDetails) Size() (n int) { if m.Counter != 0 { n += 1 + sovNebula(uint64(m.Counter)) } - if len(m.Ip6AndPorts) > 0 { - for _, e := range m.Ip6AndPorts { + if len(m.V6AddrPorts) > 0 { + for _, e := range m.V6AddrPorts { l = e.Size() n += 1 + l + sovNebula(uint64(l)) } } - if len(m.RelayVpnIp) > 0 { + if len(m.OldRelayVpnAddrs) > 0 { l = 0 - for _, e := range m.RelayVpnIp { + for _, e := range m.OldRelayVpnAddrs { l += sovNebula(uint64(e)) } n += 1 + sovNebula(uint64(l)) + l } + if m.VpnAddr != nil { + l = m.VpnAddr.Size() + n += 1 + l + sovNebula(uint64(l)) + } + if len(m.RelayVpnAddrs) > 0 { + for _, e := range m.RelayVpnAddrs { + l = e.Size() + n += 1 + l + sovNebula(uint64(l)) + } + } return n } -func (m *Ip4AndPort) Size() (n int) { +func (m *Addr) Size() (n int) { if m == nil { return 0 } var l int _ = l - if m.Ip != 0 { - n += 1 + sovNebula(uint64(m.Ip)) + if m.Hi != 0 { + n += 1 + sovNebula(uint64(m.Hi)) + } + if m.Lo != 0 { + n += 1 + sovNebula(uint64(m.Lo)) + } + return n +} + +func (m *V4AddrPort) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Addr != 0 { + n += 1 + sovNebula(uint64(m.Addr)) } if m.Port != 0 { n += 1 + sovNebula(uint64(m.Port)) @@ -1293,7 +1508,7 @@ func (m *Ip4AndPort) Size() (n int) { return n } -func (m *Ip6AndPort) Size() (n int) { +func (m *V6AddrPort) Size() (n int) { if m == nil { return 0 } @@ -1394,6 +1609,9 @@ func (m *NebulaHandshakeDetails) Size() (n int) { l = m.ResponderMultiPort.Size() n += 1 + l + sovNebula(uint64(l)) } + if m.CertVersion != 0 { + n += 1 + sovNebula(uint64(m.CertVersion)) + } return n } @@ -1412,11 +1630,19 @@ func (m *NebulaControl) Size() (n int) { if m.ResponderRelayIndex != 0 { n += 1 + sovNebula(uint64(m.ResponderRelayIndex)) } - if m.RelayToIp != 0 { - n += 1 + sovNebula(uint64(m.RelayToIp)) + if m.OldRelayToAddr != 0 { + n += 1 + sovNebula(uint64(m.OldRelayToAddr)) } - if m.RelayFromIp != 0 { - n += 1 + sovNebula(uint64(m.RelayFromIp)) + if m.OldRelayFromAddr != 0 { + n += 1 + sovNebula(uint64(m.OldRelayFromAddr)) + } + if m.RelayToAddr != nil { + l = m.RelayToAddr.Size() + n += 1 + l + sovNebula(uint64(l)) + } + if m.RelayFromAddr != nil { + l = m.RelayFromAddr.Size() + n += 1 + l + sovNebula(uint64(l)) } return n } @@ -1563,9 +1789,9 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { switch fieldNum { case 1: if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field VpnIp", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field OldVpnAddr", wireType) } - m.VpnIp = 0 + m.OldVpnAddr = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula @@ -1575,14 +1801,14 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - m.VpnIp |= uint32(b&0x7F) << shift + m.OldVpnAddr |= uint32(b&0x7F) << shift if b < 0x80 { break } } case 2: if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Ip4AndPorts", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field V4AddrPorts", wireType) } var msglen int for shift := uint(0); ; shift += 7 { @@ -1609,8 +1835,8 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { if postIndex > l { return io.ErrUnexpectedEOF } - m.Ip4AndPorts = append(m.Ip4AndPorts, &Ip4AndPort{}) - if err := m.Ip4AndPorts[len(m.Ip4AndPorts)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + m.V4AddrPorts = append(m.V4AddrPorts, &V4AddrPort{}) + if err := m.V4AddrPorts[len(m.V4AddrPorts)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx = postIndex @@ -1635,7 +1861,7 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { } case 4: if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Ip6AndPorts", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field V6AddrPorts", wireType) } var msglen int for shift := uint(0); ; shift += 7 { @@ -1662,8 +1888,8 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { if postIndex > l { return io.ErrUnexpectedEOF } - m.Ip6AndPorts = append(m.Ip6AndPorts, &Ip6AndPort{}) - if err := m.Ip6AndPorts[len(m.Ip6AndPorts)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + m.V6AddrPorts = append(m.V6AddrPorts, &V6AddrPort{}) + if err := m.V6AddrPorts[len(m.V6AddrPorts)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx = postIndex @@ -1684,7 +1910,7 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { break } } - m.RelayVpnIp = append(m.RelayVpnIp, v) + m.OldRelayVpnAddrs = append(m.OldRelayVpnAddrs, v) } else if wireType == 2 { var packedLen int for shift := uint(0); ; shift += 7 { @@ -1719,8 +1945,8 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { } } elementCount = count - if elementCount != 0 && len(m.RelayVpnIp) == 0 { - m.RelayVpnIp = make([]uint32, 0, elementCount) + if elementCount != 0 && len(m.OldRelayVpnAddrs) == 0 { + m.OldRelayVpnAddrs = make([]uint32, 0, elementCount) } for iNdEx < postIndex { var v uint32 @@ -1738,11 +1964,81 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { break } } - m.RelayVpnIp = append(m.RelayVpnIp, v) + m.OldRelayVpnAddrs = append(m.OldRelayVpnAddrs, v) } } else { - return fmt.Errorf("proto: wrong wireType = %d for field RelayVpnIp", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field OldRelayVpnAddrs", wireType) } + case 6: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field VpnAddr", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.VpnAddr == nil { + m.VpnAddr = &Addr{} + } + if err := m.VpnAddr.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 7: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field RelayVpnAddrs", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.RelayVpnAddrs = append(m.RelayVpnAddrs, &Addr{}) + if err := m.RelayVpnAddrs[len(m.RelayVpnAddrs)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipNebula(dAtA[iNdEx:]) @@ -1764,7 +2060,7 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { } return nil } -func (m *Ip4AndPort) Unmarshal(dAtA []byte) error { +func (m *Addr) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { @@ -1787,17 +2083,17 @@ func (m *Ip4AndPort) Unmarshal(dAtA []byte) error { fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { - return fmt.Errorf("proto: Ip4AndPort: wiretype end group for non-group") + return fmt.Errorf("proto: Addr: wiretype end group for non-group") } if fieldNum <= 0 { - return fmt.Errorf("proto: Ip4AndPort: illegal tag %d (wire type %d)", fieldNum, wire) + return fmt.Errorf("proto: Addr: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field Ip", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field Hi", wireType) } - m.Ip = 0 + m.Hi = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula @@ -1807,7 +2103,95 @@ func (m *Ip4AndPort) Unmarshal(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - m.Ip |= uint32(b&0x7F) << shift + m.Hi |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + case 2: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Lo", wireType) + } + m.Lo = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Lo |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipNebula(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthNebula + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *V4AddrPort) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: V4AddrPort: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: V4AddrPort: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Addr", wireType) + } + m.Addr = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Addr |= uint32(b&0x7F) << shift if b < 0x80 { break } @@ -1852,7 +2236,7 @@ func (m *Ip4AndPort) Unmarshal(dAtA []byte) error { } return nil } -func (m *Ip6AndPort) Unmarshal(dAtA []byte) error { +func (m *V6AddrPort) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { @@ -1875,10 +2259,10 @@ func (m *Ip6AndPort) Unmarshal(dAtA []byte) error { fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { - return fmt.Errorf("proto: Ip6AndPort: wiretype end group for non-group") + return fmt.Errorf("proto: V6AddrPort: wiretype end group for non-group") } if fieldNum <= 0 { - return fmt.Errorf("proto: Ip6AndPort: illegal tag %d (wire type %d)", fieldNum, wire) + return fmt.Errorf("proto: V6AddrPort: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: @@ -2506,6 +2890,25 @@ func (m *NebulaHandshakeDetails) Unmarshal(dAtA []byte) error { return err } iNdEx = postIndex + case 8: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field CertVersion", wireType) + } + m.CertVersion = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.CertVersion |= uint32(b&0x7F) << shift + if b < 0x80 { + break + } + } default: iNdEx = preIndex skippy, err := skipNebula(dAtA[iNdEx:]) @@ -2615,9 +3018,9 @@ func (m *NebulaControl) Unmarshal(dAtA []byte) error { } case 4: if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field RelayToIp", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field OldRelayToAddr", wireType) } - m.RelayToIp = 0 + m.OldRelayToAddr = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula @@ -2627,16 +3030,16 @@ func (m *NebulaControl) Unmarshal(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - m.RelayToIp |= uint32(b&0x7F) << shift + m.OldRelayToAddr |= uint32(b&0x7F) << shift if b < 0x80 { break } } case 5: if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field RelayFromIp", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field OldRelayFromAddr", wireType) } - m.RelayFromIp = 0 + m.OldRelayFromAddr = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula @@ -2646,11 +3049,83 @@ func (m *NebulaControl) Unmarshal(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - m.RelayFromIp |= uint32(b&0x7F) << shift + m.OldRelayFromAddr |= uint32(b&0x7F) << shift if b < 0x80 { break } } + case 6: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field RelayToAddr", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.RelayToAddr == nil { + m.RelayToAddr = &Addr{} + } + if err := m.RelayToAddr.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 7: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field RelayFromAddr", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.RelayFromAddr == nil { + m.RelayFromAddr = &Addr{} + } + if err := m.RelayFromAddr.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipNebula(dAtA[iNdEx:]) diff --git a/nebula.proto b/nebula.proto index a7928bd..6123f63 100644 --- a/nebula.proto +++ b/nebula.proto @@ -23,19 +23,28 @@ message NebulaMeta { } message NebulaMetaDetails { - uint32 VpnIp = 1; - repeated Ip4AndPort Ip4AndPorts = 2; - repeated Ip6AndPort Ip6AndPorts = 4; - repeated uint32 RelayVpnIp = 5; + uint32 OldVpnAddr = 1 [deprecated = true]; + Addr VpnAddr = 6; + + repeated uint32 OldRelayVpnAddrs = 5 [deprecated = true]; + repeated Addr RelayVpnAddrs = 7; + + repeated V4AddrPort V4AddrPorts = 2; + repeated V6AddrPort V6AddrPorts = 4; uint32 counter = 3; } -message Ip4AndPort { - uint32 Ip = 1; +message Addr { + uint64 Hi = 1; + uint64 Lo = 2; +} + +message V4AddrPort { + uint32 Addr = 1; uint32 Port = 2; } -message Ip6AndPort { +message V6AddrPort { uint64 Hi = 1; uint64 Lo = 2; uint32 Port = 3; @@ -69,6 +78,7 @@ message NebulaHandshakeDetails { uint32 ResponderIndex = 3; uint64 Cookie = 4; uint64 Time = 5; + uint32 CertVersion = 8; MultiPortDetails InitiatorMultiPort = 6; MultiPortDetails ResponderMultiPort = 7; @@ -84,6 +94,10 @@ message NebulaControl { uint32 InitiatorRelayIndex = 2; uint32 ResponderRelayIndex = 3; - uint32 RelayToIp = 4; - uint32 RelayFromIp = 5; + + uint32 OldRelayToAddr = 4 [deprecated = true]; + uint32 OldRelayFromAddr = 5 [deprecated = true]; + + Addr RelayToAddr = 6; + Addr RelayFromAddr = 7; } diff --git a/noiseutil/pkcs11.go b/noiseutil/pkcs11.go new file mode 100644 index 0000000..d1c7ba9 --- /dev/null +++ b/noiseutil/pkcs11.go @@ -0,0 +1,50 @@ +package noiseutil + +import ( + "crypto/ecdh" + "fmt" + "strings" + + "github.com/slackhq/nebula/pkclient" + + "github.com/flynn/noise" +) + +// DHP256PKCS11 is the NIST P-256 ECDH function +var DHP256PKCS11 noise.DHFunc = newNISTP11Curve("P256", ecdh.P256(), 32) + +type nistP11Curve struct { + nistCurve +} + +func newNISTP11Curve(name string, curve ecdh.Curve, byteLen int) nistP11Curve { + return nistP11Curve{ + newNISTCurve(name, curve, byteLen), + } +} + +func (c nistP11Curve) DH(privkey, pubkey []byte) ([]byte, error) { + //for this function "privkey" is actually a pkcs11 URI + pkStr := string(privkey) + + //to set up a handshake, we need to also do non-pkcs11-DH. Handle that here. + if !strings.HasPrefix(pkStr, "pkcs11:") { + return DHP256.DH(privkey, pubkey) + } + ecdhPubKey, err := c.curve.NewPublicKey(pubkey) + if err != nil { + return nil, fmt.Errorf("unable to unmarshal pubkey: %w", err) + } + + //this is not the most performant way to do this (a long-lived client would be better) + //but, it works, and helps avoid problems with stale sessions and HSMs used by multiple users. + client, err := pkclient.FromUrl(pkStr) + if err != nil { + return nil, err + } + defer func(client *pkclient.PKClient) { + _ = client.Close() + }(client) + + return client.DeriveNoise(ecdhPubKey.Bytes()) +} diff --git a/outside.go b/outside.go index 538e3e8..5235371 100644 --- a/outside.go +++ b/outside.go @@ -3,46 +3,25 @@ package nebula import ( "encoding/binary" "errors" - "fmt" "net/netip" "time" - "github.com/flynn/noise" + "github.com/google/gopacket/layers" + "golang.org/x/net/ipv6" + "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/udp" "golang.org/x/net/ipv4" - "google.golang.org/protobuf/proto" ) const ( minFwPacketLen = 4 ) -// TODO: IPV6-WORK this can likely be removed now -func readOutsidePackets(f *Interface) udp.EncReader { - return func( - addr netip.AddrPort, - out []byte, - packet []byte, - header *header.H, - fwPacket *firewall.Packet, - lhh udp.LightHouseHandlerFunc, - nb []byte, - q int, - localCache firewall.ConntrackCache, - ) { - f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache) - } -} - -func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { err := h.Parse(packet) if err != nil { - // TODO: best if we return this and let caller log - // TODO: Might be better to send the literal []byte("holepunch") packet and ignore that? // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors if len(packet) > 1 { f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err) @@ -52,7 +31,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] //l.Error("in packet ", header, packet[HeaderLen:]) if ip.IsValid() { - if f.myVpnNet.Contains(ip.Addr()) { + _, found := f.myVpnNetworksTable.Lookup(ip.Addr()) + if found { if f.l.Level >= logrus.DebugLevel { f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet") } @@ -109,7 +89,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] if !ok { // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing // its internal mapping. This should never happen. - hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnIp": hostinfo.vpnIp, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index") + hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index") return } @@ -121,9 +101,9 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] return case ForwardingType: // Find the target HostInfo relay object - targetHI, targetRelay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relay.PeerIp) + targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) if err != nil { - hostinfo.logger(f.l).WithField("relayTo", relay.PeerIp).WithError(err).Info("Failed to find target host info by ip") + hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip") return } @@ -139,7 +119,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") } } else { - hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerIp, "relayFrom": hostinfo.vpnIp, "targetRelayState": targetRelay.State}).Info("Unexpected target relay state") + hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state") return } } @@ -156,13 +136,10 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). WithField("packet", packet). Error("Failed to decrypt lighthouse packet") - - //TODO: maybe after build 64 is out? 06/14/2018 - NB - //f.sendRecvError(net.Addr(addr), header.RemoteIndex) return } - lhf(ip, hostinfo.vpnIp, d) + lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f) // Fallthrough to the bottom to record incoming traffic @@ -177,9 +154,6 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). WithField("packet", packet). Error("Failed to decrypt test packet") - - //TODO: maybe after build 64 is out? 06/14/2018 - NB - //f.sendRecvError(net.Addr(addr), header.RemoteIndex) return } @@ -229,14 +203,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] Error("Failed to decrypt Control packet") return } - m := &NebulaControl{} - err = m.Unmarshal(d) - if err != nil { - hostinfo.logger(f.l).WithError(err).Error("Failed to unmarshal control message") - break - } - f.relayManager.HandleControlMsg(hostinfo, m, f) + f.relayManager.HandleControlMsg(hostinfo, d, f) default: f.messageMetrics.Rx(h.Type, h.Subtype, 1) @@ -253,8 +221,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] func (f *Interface) closeTunnel(hostInfo *HostInfo) { final := f.hostMap.DeleteHostInfo(hostInfo) if final { - // We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage - f.lightHouse.DeleteVpnIp(hostInfo.vpnIp) + // We no longer have any tunnels with this vpn addr, clear learned lighthouse state to lower memory usage + f.lightHouse.DeleteVpnAddrs(hostInfo.vpnAddrs) } } @@ -263,35 +231,36 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) { f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) } -func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) { - if ip.IsValid() && hostinfo.remote != ip { +func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) { + if udpAddr.IsValid() && hostinfo.remote != udpAddr { if hostinfo.multiportRx { // If the remote is sending with multiport, we aren't roaming unless // the IP has changed - if hostinfo.remote.Addr().Compare(ip.Addr()) == 0 { + if hostinfo.remote.Addr().Compare(udpAddr.Addr()) == 0 { return } // Keep the port from the original hostinfo, because the remote is transmitting from multiport ports - ip = netip.AddrPortFrom(ip.Addr(), hostinfo.remote.Port()) + udpAddr = netip.AddrPortFrom(udpAddr.Addr(), hostinfo.remote.Port()) } - if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, ip.Addr()) { - hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming") + if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, udpAddr.Addr()) { + hostinfo.logger(f.l).WithField("newAddr", udpAddr).Debug("lighthouse.remote_allow_list denied roaming") return } - if !hostinfo.lastRoam.IsZero() && ip == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { + + if !hostinfo.lastRoam.IsZero() && udpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr). Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) } return } - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr). Info("Host roamed to new udp ip/port.") hostinfo.lastRoam = time.Now() hostinfo.lastRoamRemote = hostinfo.remote - hostinfo.SetRemote(ip) + hostinfo.SetRemote(udpAddr) } } @@ -311,24 +280,141 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h return true } +var ( + ErrPacketTooShort = errors.New("packet is too short") + ErrUnknownIPVersion = errors.New("packet is an unknown ip version") + ErrIPv4InvalidHeaderLength = errors.New("invalid ipv4 header length") + ErrIPv4PacketTooShort = errors.New("ipv4 packet is too short") + ErrIPv6PacketTooShort = errors.New("ipv6 packet is too short") + ErrIPv6CouldNotFindPayload = errors.New("could not find payload in ipv6 packet") +) + // newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { - // Do we at least have an ipv4 header worth of data? - if len(data) < ipv4.HeaderLen { - return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen) + if len(data) < 1 { + return ErrPacketTooShort } - // Is it an ipv4 packet? - if int((data[0]>>4)&0x0f) != 4 { - return fmt.Errorf("packet is not ipv4, type: %v", int((data[0]>>4)&0x0f)) + version := int((data[0] >> 4) & 0x0f) + switch version { + case ipv4.Version: + return parseV4(data, incoming, fp) + case ipv6.Version: + return parseV6(data, incoming, fp) + } + return ErrUnknownIPVersion +} + +func parseV6(data []byte, incoming bool, fp *firewall.Packet) error { + dataLen := len(data) + if dataLen < ipv6.HeaderLen { + return ErrIPv6PacketTooShort + } + + if incoming { + fp.RemoteAddr, _ = netip.AddrFromSlice(data[8:24]) + fp.LocalAddr, _ = netip.AddrFromSlice(data[24:40]) + } else { + fp.LocalAddr, _ = netip.AddrFromSlice(data[8:24]) + fp.RemoteAddr, _ = netip.AddrFromSlice(data[24:40]) + } + + protoAt := 6 // NextHeader is at 6 bytes into the ipv6 header + offset := ipv6.HeaderLen // Start at the end of the ipv6 header + next := 0 + for { + if dataLen < offset { + break + } + + proto := layers.IPProtocol(data[protoAt]) + //fmt.Println(proto, protoAt) + switch proto { + case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader: + fp.Protocol = uint8(proto) + fp.RemotePort = 0 + fp.LocalPort = 0 + fp.Fragment = false + return nil + + case layers.IPProtocolTCP, layers.IPProtocolUDP: + if dataLen < offset+4 { + return ErrIPv6PacketTooShort + } + + fp.Protocol = uint8(proto) + if incoming { + fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2]) + fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) + } else { + fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2]) + fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) + } + + fp.Fragment = false + return nil + + case layers.IPProtocolIPv6Fragment: + // Fragment header is 8 bytes, need at least offset+4 to read the offset field + if dataLen < offset+8 { + return ErrIPv6PacketTooShort + } + + // Check if this is the first fragment + fragmentOffset := binary.BigEndian.Uint16(data[offset+2:offset+4]) &^ uint16(0x7) // Remove the reserved and M flag bits + if fragmentOffset != 0 { + // Non-first fragment, use what we have now and stop processing + fp.Protocol = data[offset] + fp.Fragment = true + fp.RemotePort = 0 + fp.LocalPort = 0 + return nil + } + + // The next loop should be the transport layer since we are the first fragment + next = 8 // Fragment headers are always 8 bytes + + case layers.IPProtocolAH: + // Auth headers, used by IPSec, have a different meaning for header length + if dataLen < offset+1 { + break + } + + next = int(data[offset+1]+2) << 2 + + default: + // Normal ipv6 header length processing + if dataLen < offset+1 { + break + } + + next = int(data[offset+1]+1) << 3 + } + + if next <= 0 { + // Safety check, each ipv6 header has to be at least 8 bytes + next = 8 + } + + protoAt = offset + offset = offset + next + } + + return ErrIPv6CouldNotFindPayload +} + +func parseV4(data []byte, incoming bool, fp *firewall.Packet) error { + // Do we at least have an ipv4 header worth of data? + if len(data) < ipv4.HeaderLen { + return ErrIPv4PacketTooShort } // Adjust our start position based on the advertised ip header length ihl := int(data[0]&0x0f) << 2 - // Well formed ip header length? + // Well-formed ip header length? if ihl < ipv4.HeaderLen { - return fmt.Errorf("packet had an invalid header length: %v", ihl) + return ErrIPv4InvalidHeaderLength } // Check if this is the second or further fragment of a fragmented packet. @@ -344,14 +430,13 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { minLen += minFwPacketLen } if len(data) < minLen { - return fmt.Errorf("packet is less than %v bytes, ip header len: %v", minLen, ihl) + return ErrIPv4InvalidHeaderLength } // Firewall packets are locally oriented if incoming { - //TODO: IPV6-WORK - fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16]) - fp.LocalIP, _ = netip.AddrFromSlice(data[16:20]) + fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16]) + fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20]) if fp.Fragment || fp.Protocol == firewall.ProtoICMP { fp.RemotePort = 0 fp.LocalPort = 0 @@ -360,9 +445,8 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) } } else { - //TODO: IPV6-WORK - fp.LocalIP, _ = netip.AddrFromSlice(data[12:16]) - fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20]) + fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16]) + fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20]) if fp.Fragment || fp.Protocol == firewall.ProtoICMP { fp.RemotePort = 0 fp.LocalPort = 0 @@ -397,8 +481,6 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) if err != nil { hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet") - //TODO: maybe after build 64 is out? 06/14/2018 - NB - //f.sendRecvError(hostinfo.remote, header.RemoteIndex) return false } @@ -445,9 +527,8 @@ func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) { func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) { f.messageMetrics.Tx(header.RecvError, 0, 1) - //TODO: this should be a signed message so we can trust that we should drop the index b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0) - f.outside.WriteTo(b, endpoint) + _ = f.outside.WriteTo(b, endpoint) if f.l.Level >= logrus.DebugLevel { f.l.WithField("index", index). WithField("udpAddr", endpoint). @@ -481,65 +562,3 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) { // We also delete it from pending hostmap to allow for fast reconnect. f.handshakeManager.DeleteHostInfo(hostinfo) } - -/* -func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *NebulaMeta) { - if ci.eKey != nil { - //TODO: log error? - return - } - - msg, err := proto.Marshal(meta) - if err != nil { - l.Debugln("failed to encode header") - } - - c := ci.messageCounter - b := HeaderEncode(nil, Version, uint8(metadata), 0, hostinfo.remoteIndexId, c) - ci.messageCounter++ - - msg := ci.eKey.EncryptDanger(b, nil, msg, c) - //msg := ci.eKey.EncryptDanger(b, nil, []byte(fmt.Sprintf("%d", counter)), c) - f.outside.WriteTo(msg, endpoint) -} -*/ - -func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.NebulaCAPool) (*cert.NebulaCertificate, error) { - pk := h.PeerStatic() - - if pk == nil { - return nil, errors.New("no peer static key was present") - } - - if rawCertBytes == nil { - return nil, errors.New("provided payload was empty") - } - - r := &cert.RawNebulaCertificate{} - err := proto.Unmarshal(rawCertBytes, r) - if err != nil { - return nil, fmt.Errorf("error unmarshaling cert: %s", err) - } - - // If the Details are nil, just exit to avoid crashing - if r.Details == nil { - return nil, fmt.Errorf("certificate did not contain any details") - } - - r.Details.PublicKey = pk - recombined, err := proto.Marshal(r) - if err != nil { - return nil, fmt.Errorf("error while recombining certificate: %s", err) - } - - c, _ := cert.UnmarshalNebulaCertificate(recombined) - isValid, err := c.Verify(time.Now(), caPool) - if err != nil { - return c, fmt.Errorf("certificate validation failed: %s", err) - } else if !isValid { - // This case should never happen but here's to defensive programming! - return c, errors.New("certificate validation failed but did not return an error") - } - - return c, nil -} diff --git a/outside_test.go b/outside_test.go index f9d4bfa..f197594 100644 --- a/outside_test.go +++ b/outside_test.go @@ -1,10 +1,15 @@ package nebula import ( + "bytes" + "encoding/binary" "net" "net/netip" "testing" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/slackhq/nebula/firewall" "github.com/stretchr/testify/assert" "golang.org/x/net/ipv4" @@ -13,9 +18,15 @@ import ( func Test_newPacket(t *testing.T) { p := &firewall.Packet{} - // length fail - err := newPacket([]byte{0, 1}, true, p) - assert.EqualError(t, err, "packet is less than 20 bytes") + // length fails + err := newPacket([]byte{}, true, p) + assert.ErrorIs(t, err, ErrPacketTooShort) + + err = newPacket([]byte{0x40}, true, p) + assert.ErrorIs(t, err, ErrIPv4PacketTooShort) + + err = newPacket([]byte{0x60}, true, p) + assert.ErrorIs(t, err, ErrIPv6PacketTooShort) // length fail with ip options h := ipv4.Header{ @@ -28,16 +39,15 @@ func Test_newPacket(t *testing.T) { b, _ := h.Marshal() err = newPacket(b, true, p) - - assert.EqualError(t, err, "packet is less than 28 bytes, ip header len: 24") + assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) // not an ipv4 packet err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) - assert.EqualError(t, err, "packet is not ipv4, type: 0") + assert.ErrorIs(t, err, ErrUnknownIPVersion) // invalid ihl err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) - assert.EqualError(t, err, "packet had an invalid header length: 8") + assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) // account for variable ip header length - incoming h = ipv4.Header{ @@ -54,11 +64,12 @@ func Test_newPacket(t *testing.T) { err = newPacket(b, true, p) assert.Nil(t, err) - assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP)) - assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.2")) - assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.1")) - assert.Equal(t, p.RemotePort, uint16(3)) - assert.Equal(t, p.LocalPort, uint16(4)) + assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr) + assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr) + assert.Equal(t, uint16(3), p.RemotePort) + assert.Equal(t, uint16(4), p.LocalPort) + assert.False(t, p.Fragment) // account for variable ip header length - outgoing h = ipv4.Header{ @@ -75,9 +86,506 @@ func Test_newPacket(t *testing.T) { err = newPacket(b, false, p) assert.Nil(t, err) - assert.Equal(t, p.Protocol, uint8(2)) - assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.1")) - assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.2")) - assert.Equal(t, p.RemotePort, uint16(6)) - assert.Equal(t, p.LocalPort, uint16(5)) + assert.Equal(t, uint8(2), p.Protocol) + assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr) + assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr) + assert.Equal(t, uint16(6), p.RemotePort) + assert.Equal(t, uint16(5), p.LocalPort) + assert.False(t, p.Fragment) +} + +func Test_newPacket_v6(t *testing.T) { + p := &firewall.Packet{} + + // invalid ipv6 + ip := layers.IPv6{ + Version: 6, + HopLimit: 128, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + buffer := gopacket.NewSerializeBuffer() + opt := gopacket.SerializeOptions{ + ComputeChecksums: false, + FixLengths: false, + } + err := gopacket.SerializeLayers(buffer, opt, &ip) + assert.NoError(t, err) + + err = newPacket(buffer.Bytes(), true, p) + assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + + // A good ICMP packet + ip = layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolICMPv6, + HopLimit: 128, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + icmp := layers.ICMPv6{} + + buffer.Clear() + err = gopacket.SerializeLayers(buffer, opt, &ip, &icmp) + if err != nil { + panic(err) + } + + err = newPacket(buffer.Bytes(), true, p) + assert.Nil(t, err) + assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint16(0), p.RemotePort) + assert.Equal(t, uint16(0), p.LocalPort) + assert.False(t, p.Fragment) + + // A good ESP packet + b := buffer.Bytes() + b[6] = byte(layers.IPProtocolESP) + err = newPacket(b, true, p) + assert.Nil(t, err) + assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint16(0), p.RemotePort) + assert.Equal(t, uint16(0), p.LocalPort) + assert.False(t, p.Fragment) + + // A good None packet + b = buffer.Bytes() + b[6] = byte(layers.IPProtocolNoNextHeader) + err = newPacket(b, true, p) + assert.Nil(t, err) + assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint16(0), p.RemotePort) + assert.Equal(t, uint16(0), p.LocalPort) + assert.False(t, p.Fragment) + + // An unknown protocol packet + b = buffer.Bytes() + b[6] = 255 // 255 is a reserved protocol number + err = newPacket(b, true, p) + assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + + // A good UDP packet + ip = layers.IPv6{ + Version: 6, + NextHeader: firewall.ProtoUDP, + HopLimit: 128, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + udp := layers.UDP{ + SrcPort: layers.UDPPort(36123), + DstPort: layers.UDPPort(22), + } + err = udp.SetNetworkLayerForChecksum(&ip) + assert.NoError(t, err) + + buffer.Clear() + err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef})) + if err != nil { + panic(err) + } + b = buffer.Bytes() + + // incoming + err = newPacket(b, true, p) + assert.Nil(t, err) + assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint16(36123), p.RemotePort) + assert.Equal(t, uint16(22), p.LocalPort) + assert.False(t, p.Fragment) + + // outgoing + err = newPacket(b, false, p) + assert.Nil(t, err) + assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) + assert.Equal(t, uint16(36123), p.LocalPort) + assert.Equal(t, uint16(22), p.RemotePort) + assert.False(t, p.Fragment) + + // Too short UDP packet + err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes + assert.ErrorIs(t, err, ErrIPv6PacketTooShort) + + // A good TCP packet + b[6] = byte(layers.IPProtocolTCP) + + // incoming + err = newPacket(b, true, p) + assert.Nil(t, err) + assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint16(36123), p.RemotePort) + assert.Equal(t, uint16(22), p.LocalPort) + assert.False(t, p.Fragment) + + // outgoing + err = newPacket(b, false, p) + assert.Nil(t, err) + assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) + assert.Equal(t, uint16(36123), p.LocalPort) + assert.Equal(t, uint16(22), p.RemotePort) + assert.False(t, p.Fragment) + + // Too short TCP packet + err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes + assert.ErrorIs(t, err, ErrIPv6PacketTooShort) + + // A good UDP packet with an AH header + ip = layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolAH, + HopLimit: 128, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + ah := layers.IPSecAH{ + AuthenticationData: []byte{0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef}, + } + ah.NextHeader = layers.IPProtocolUDP + + udpHeader := []byte{ + 0x8d, 0x1b, // Source port 36123 + 0x00, 0x16, // Destination port 22 + 0x00, 0x00, // Length + 0x00, 0x00, // Checksum + } + + buffer.Clear() + err = ip.SerializeTo(buffer, opt) + if err != nil { + panic(err) + } + + b = buffer.Bytes() + ahb := serializeAH(&ah) + b = append(b, ahb...) + b = append(b, udpHeader...) + + err = newPacket(b, true, p) + assert.Nil(t, err) + assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint16(36123), p.RemotePort) + assert.Equal(t, uint16(22), p.LocalPort) + assert.False(t, p.Fragment) + + // Invalid AH header + b = buffer.Bytes() + err = newPacket(b, true, p) + assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) +} + +func Test_newPacket_ipv6Fragment(t *testing.T) { + p := &firewall.Packet{} + + ip := &layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolIPv6Fragment, + HopLimit: 64, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + // First fragment + fragHeader1 := []byte{ + uint8(layers.IPProtocolUDP), // Next Header (UDP) + 0x00, // Reserved + 0x00, // Fragment Offset high byte (0) + 0x01, // Fragment Offset low byte & flags (M=1) + 0x00, 0x00, 0x00, 0x01, // Identification + } + + udpHeader := []byte{ + 0x8d, 0x1b, // Source port 36123 + 0x00, 0x16, // Destination port 22 + 0x00, 0x00, // Length + 0x00, 0x00, // Checksum + } + + buffer := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + + err := ip.SerializeTo(buffer, opts) + if err != nil { + t.Fatal(err) + } + + firstFrag := buffer.Bytes() + firstFrag = append(firstFrag, fragHeader1...) + firstFrag = append(firstFrag, udpHeader...) + firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...) + + // Test first fragment incoming + err = newPacket(firstFrag, true, p) + assert.NoError(t, err) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) + assert.Equal(t, uint16(36123), p.RemotePort) + assert.Equal(t, uint16(22), p.LocalPort) + assert.False(t, p.Fragment) + + // Test first fragment outgoing + err = newPacket(firstFrag, false, p) + assert.NoError(t, err) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) + assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) + assert.Equal(t, uint16(36123), p.LocalPort) + assert.Equal(t, uint16(22), p.RemotePort) + assert.False(t, p.Fragment) + + // Second fragment + fragHeader2 := []byte{ + uint8(layers.IPProtocolUDP), // Next Header (UDP) + 0x00, // Reserved + 0xb9, // Fragment Offset high byte (185) + 0x01, // Fragment Offset low byte & flags (M=1) + 0x00, 0x00, 0x00, 0x01, // Identification + } + + buffer.Clear() + err = ip.SerializeTo(buffer, opts) + if err != nil { + t.Fatal(err) + } + + secondFrag := buffer.Bytes() + secondFrag = append(secondFrag, fragHeader2...) + secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...) + + // Test second fragment incoming + err = newPacket(secondFrag, true, p) + assert.NoError(t, err) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) + assert.Equal(t, uint16(0), p.RemotePort) + assert.Equal(t, uint16(0), p.LocalPort) + assert.True(t, p.Fragment) + + // Test second fragment outgoing + err = newPacket(secondFrag, false, p) + assert.NoError(t, err) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) + assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) + assert.Equal(t, uint16(0), p.LocalPort) + assert.Equal(t, uint16(0), p.RemotePort) + assert.True(t, p.Fragment) + + // Too short of a fragment packet + err = newPacket(secondFrag[:len(secondFrag)-10], false, p) + assert.ErrorIs(t, err, ErrIPv6PacketTooShort) +} + +func BenchmarkParseV6(b *testing.B) { + // Regular UDP packet + ip := &layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolUDP, + HopLimit: 64, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + udp := &layers.UDP{ + SrcPort: layers.UDPPort(36123), + DstPort: layers.UDPPort(22), + } + + buffer := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: false, + FixLengths: true, + } + + err := gopacket.SerializeLayers(buffer, opts, ip, udp) + if err != nil { + b.Fatal(err) + } + normalPacket := buffer.Bytes() + + // First Fragment packet + ipFrag := &layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolIPv6Fragment, + HopLimit: 64, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + fragHeader := []byte{ + uint8(layers.IPProtocolUDP), // Next Header (UDP) + 0x00, // Reserved + 0x00, // Fragment Offset high byte (0) + 0x01, // Fragment Offset low byte & flags (M=1) + 0x00, 0x00, 0x00, 0x01, // Identification + } + + udpHeader := []byte{ + 0x8d, 0x7b, // Source port 36123 + 0x00, 0x16, // Destination port 22 + 0x00, 0x00, // Length + 0x00, 0x00, // Checksum + } + + buffer.Clear() + err = ipFrag.SerializeTo(buffer, opts) + if err != nil { + b.Fatal(err) + } + + firstFrag := buffer.Bytes() + firstFrag = append(firstFrag, fragHeader...) + firstFrag = append(firstFrag, udpHeader...) + firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...) + + // Second Fragment packet + fragHeader[2] = 0xb9 // offset 185 + buffer.Clear() + err = ipFrag.SerializeTo(buffer, opts) + if err != nil { + b.Fatal(err) + } + + secondFrag := buffer.Bytes() + secondFrag = append(secondFrag, fragHeader...) + secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...) + + fp := &firewall.Packet{} + + b.Run("Normal", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if err = parseV6(normalPacket, true, fp); err != nil { + b.Fatal(err) + } + } + }) + + b.Run("FirstFragment", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if err = parseV6(firstFrag, true, fp); err != nil { + b.Fatal(err) + } + } + }) + + b.Run("SecondFragment", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if err = parseV6(secondFrag, true, fp); err != nil { + b.Fatal(err) + } + } + }) + + // Evil packet + evilPacket := &layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolIPv6HopByHop, + HopLimit: 64, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + hopHeader := []byte{ + uint8(layers.IPProtocolIPv6HopByHop), // Next Header (HopByHop) + 0x00, // Length + 0x00, 0x00, // Options and padding + 0x00, 0x00, 0x00, 0x00, // More options and padding + } + + lastHopHeader := []byte{ + uint8(layers.IPProtocolUDP), // Next Header (UDP) + 0x00, // Length + 0x00, 0x00, // Options and padding + 0x00, 0x00, 0x00, 0x00, // More options and padding + } + + buffer.Clear() + err = evilPacket.SerializeTo(buffer, opts) + if err != nil { + b.Fatal(err) + } + + evilBytes := buffer.Bytes() + for i := 0; i < 200; i++ { + evilBytes = append(evilBytes, hopHeader...) + } + evilBytes = append(evilBytes, lastHopHeader...) + evilBytes = append(evilBytes, udpHeader...) + evilBytes = append(evilBytes, []byte{0xde, 0xad, 0xbe, 0xef}...) + + b.Run("200 HopByHop headers", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if err = parseV6(evilBytes, false, fp); err != nil { + b.Fatal(err) + } + } + }) +} + +// Ensure authentication data is a multiple of 8 bytes by padding if necessary +func padAuthData(authData []byte) []byte { + // Length of Authentication Data must be a multiple of 8 bytes + paddingLength := (8 - (len(authData) % 8)) % 8 // Only pad if necessary + if paddingLength > 0 { + authData = append(authData, make([]byte, paddingLength)...) + } + return authData +} + +// Custom function to manually serialize IPSecAH for both IPv4 and IPv6 +func serializeAH(ah *layers.IPSecAH) []byte { + buf := new(bytes.Buffer) + + // Ensure Authentication Data is a multiple of 8 bytes + ah.AuthenticationData = padAuthData(ah.AuthenticationData) + // Calculate Payload Length (in 32-bit words, minus 2) + payloadLen := uint8((12+len(ah.AuthenticationData))/4) - 2 + + // Serialize fields + if err := binary.Write(buf, binary.BigEndian, ah.NextHeader); err != nil { + panic(err) + } + if err := binary.Write(buf, binary.BigEndian, payloadLen); err != nil { + panic(err) + } + if err := binary.Write(buf, binary.BigEndian, ah.Reserved); err != nil { + panic(err) + } + if err := binary.Write(buf, binary.BigEndian, ah.SPI); err != nil { + panic(err) + } + if err := binary.Write(buf, binary.BigEndian, ah.Seq); err != nil { + panic(err) + } + if len(ah.AuthenticationData) > 0 { + if err := binary.Write(buf, binary.BigEndian, ah.AuthenticationData); err != nil { + panic(err) + } + } + + return buf.Bytes() } diff --git a/overlay/device.go b/overlay/device.go index 50ad6ad..da8cbe9 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -8,7 +8,7 @@ import ( type Device interface { io.ReadWriteCloser Activate() error - Cidr() netip.Prefix + Networks() []netip.Prefix Name() string RouteFor(netip.Addr) netip.Addr NewMultiQueueReader() (io.ReadWriteCloser, error) diff --git a/overlay/route.go b/overlay/route.go index 8ccc994..687cc11 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -61,7 +61,7 @@ func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table return routeTree, nil } -func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) { +func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { var err error r := c.Get("tun.routes") @@ -117,12 +117,20 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) { return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err) } - if !network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() { + found := false + for _, network := range networks { + if network.Contains(r.Cidr.Addr()) && r.Cidr.Bits() >= network.Bits() { + found = true + break + } + } + + if !found { return nil, fmt.Errorf( - "entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v", + "entry %v.route in tun.routes is not contained within the configured vpn networks; route: %v, networks: %v", i+1, r.Cidr.String(), - network.String(), + networks, ) } @@ -132,7 +140,7 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) { return routes, nil } -func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) { +func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { var err error r := c.Get("tun.unsafe_routes") @@ -229,13 +237,15 @@ func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) { return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err) } - if network.Contains(r.Cidr.Addr()) { - return nil, fmt.Errorf( - "entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v", - i+1, - r.Cidr.String(), - network.String(), - ) + for _, network := range networks { + if network.Contains(r.Cidr.Addr()) { + return nil, fmt.Errorf( + "entry %v.route in tun.unsafe_routes is contained within the configured vpn networks; route: %v, network: %v", + i+1, + r.Cidr.String(), + network.String(), + ) + } } routes[i] = r diff --git a/overlay/route_test.go b/overlay/route_test.go index d791389..c60e4c2 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -17,76 +17,82 @@ func Test_parseRoutes(t *testing.T) { assert.NoError(t, err) // test no routes config - routes, err := parseRoutes(c, n) + routes, err := parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, err) assert.Len(t, routes, 0) // not an array c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "tun.routes is not an array") // no routes c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, err) assert.Len(t, routes, 0) // weird route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1 in tun.routes is invalid") // no mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present") // bad mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") // low mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499") // missing route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.route in tun.routes is not present") // unparsable route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // below network range c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 1.0.0.0/8, network: 10.0.0.0/24") + assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]") // above network range c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 10.0.1.0/24, network: 10.0.0.0/24") + assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]") + + // Not in multiple ranges + c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}} + routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")}) + assert.Nil(t, routes) + assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]") // happy case c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{ map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"}, map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"}, }} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, err) assert.Len(t, routes, 2) @@ -116,31 +122,31 @@ func Test_parseUnsafeRoutes(t *testing.T) { assert.NoError(t, err) // test no routes config - routes, err := parseUnsafeRoutes(c, n) + routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, err) assert.Len(t, routes, 0) // not an array c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "tun.unsafe_routes is not an array") // no routes c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, err) assert.Len(t, routes, 0) // weird route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid") // no via c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present") @@ -149,68 +155,68 @@ func Test_parseUnsafeRoutes(t *testing.T) { 127, false, nil, 1.0, []string{"1", "2"}, } { c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue)) } // unparsable via c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") // missing route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present") // unparsable route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // within network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the network attached to the certificate; route: 10.0.0.0/24, network: 10.0.0.0/24") + assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24") // below network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) assert.Nil(t, err) // above network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) assert.Nil(t, err) // no mtu c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) assert.Equal(t, 0, routes[0].MTU) // bad mtu c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") // low mtu c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499") // bad install c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax") @@ -221,7 +227,7 @@ func Test_parseUnsafeRoutes(t *testing.T) { map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1}, map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"}, }} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, err) assert.Len(t, routes, 4) @@ -260,7 +266,7 @@ func Test_makeRouteTree(t *testing.T) { map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"}, map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"}, }} - routes, err := parseUnsafeRoutes(c, n) + routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) assert.NoError(t, err) assert.Len(t, routes, 2) routeTree, err := makeRouteTree(l, routes, true) diff --git a/overlay/tun.go b/overlay/tun.go index 12460da..4a6377d 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -11,36 +11,36 @@ import ( const DefaultMTU = 1300 // TODO: We may be able to remove routines -type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) +type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) -func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { +func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { switch { case c.GetBool("tun.disabled", false): - tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) + tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) return tun, nil default: - return newTun(c, l, tunCidr, routines > 1) + return newTun(c, l, vpnNetworks, routines > 1) } } func NewFdDeviceFromConfig(fd *int) DeviceFactory { - return func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { - return newTunFromFd(c, l, *fd, tunCidr) + return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { + return newTunFromFd(c, l, *fd, vpnNetworks) } } -func getAllRoutesFromConfig(c *config.C, cidr netip.Prefix, initial bool) (bool, []Route, error) { +func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) { if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") { return false, nil, nil } - routes, err := parseRoutes(c, cidr) + routes, err := parseRoutes(c, vpnNetworks) if err != nil { return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err) } - unsafeRoutes, err := parseUnsafeRoutes(c, cidr) + unsafeRoutes, err := parseUnsafeRoutes(c, vpnNetworks) if err != nil { return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err) } diff --git a/overlay/tun_android.go b/overlay/tun_android.go index 98ad9b4..72a6565 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -18,14 +18,14 @@ import ( type tun struct { io.ReadWriteCloser - fd int - cidr netip.Prefix - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + fd int + vpnNetworks []netip.Prefix + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { // XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly. // Be sure not to call file.Fd() as it will set the fd to blocking mode. file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") @@ -33,7 +33,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix t := &tun{ ReadWriteCloser: file, fd: deviceFd, - cidr: cidr, + vpnNetworks: vpnNetworks, l: l, } @@ -52,7 +52,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix return t, nil } -func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in Android") } @@ -66,7 +66,7 @@ func (t tun) Activate() error { } func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -86,8 +86,8 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 0b573e6..1a02b49 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -24,56 +24,62 @@ import ( type tun struct { io.ReadWriteCloser - Device string - cidr netip.Prefix - DefaultMTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - linkAddr *netroute.LinkAddr - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + DefaultMTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + linkAddr *netroute.LinkAddr + l *logrus.Logger // cache out buffer since we need to prepend 4 bytes for tun metadata out []byte } -type sockaddrCtl struct { - scLen uint8 - scFamily uint8 - ssSysaddr uint16 - scID uint32 - scUnit uint32 - scReserved [5]uint32 -} - type ifReq struct { - Name [16]byte + Name [unix.IFNAMSIZ]byte Flags uint16 pad [8]byte } -var sockaddrCtlSize uintptr = 32 - const ( - _SYSPROTO_CONTROL = 2 //define SYSPROTO_CONTROL 2 /* kernel control protocol */ - _AF_SYS_CONTROL = 2 //#define AF_SYS_CONTROL 2 /* corresponding sub address type */ - _PF_SYSTEM = unix.AF_SYSTEM //#define PF_SYSTEM AF_SYSTEM - _CTLIOCGINFO = 3227799043 //#define CTLIOCGINFO _IOWR('N', 3, struct ctl_info) - utunControlName = "com.apple.net.utun_control" + _SIOCAIFADDR_IN6 = 2155899162 + _UTUN_OPT_IFNAME = 2 + _IN6_IFF_NODAD = 0x0020 + _IN6_IFF_SECURED = 0x0400 + utunControlName = "com.apple.net.utun_control" ) -type ifreqAddr struct { - Name [16]byte - Addr unix.RawSockaddrInet4 - pad [8]byte -} - type ifreqMTU struct { Name [16]byte MTU int32 pad [8]byte } -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { +type addrLifetime struct { + Expire float64 + Preferred float64 + Vltime uint32 + Pltime uint32 +} + +type ifreqAlias4 struct { + Name [unix.IFNAMSIZ]byte + Addr unix.RawSockaddrInet4 + DstAddr unix.RawSockaddrInet4 + MaskAddr unix.RawSockaddrInet4 +} + +type ifreqAlias6 struct { + Name [unix.IFNAMSIZ]byte + Addr unix.RawSockaddrInet6 + DstAddr unix.RawSockaddrInet6 + PrefixMask unix.RawSockaddrInet6 + Flags uint32 + Lifetime addrLifetime +} + +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { name := c.GetString("tun.dev", "") ifIndex := -1 if name != "" && name != "utun" { @@ -86,66 +92,41 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err } } - fd, err := unix.Socket(_PF_SYSTEM, unix.SOCK_DGRAM, _SYSPROTO_CONTROL) + fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, unix.AF_SYS_CONTROL) if err != nil { return nil, fmt.Errorf("system socket: %v", err) } - var ctlInfo = &struct { - ctlID uint32 - ctlName [96]byte - }{} + var ctlInfo = &unix.CtlInfo{} + copy(ctlInfo.Name[:], utunControlName) - copy(ctlInfo.ctlName[:], utunControlName) - - err = ioctl(uintptr(fd), uintptr(_CTLIOCGINFO), uintptr(unsafe.Pointer(ctlInfo))) + err = unix.IoctlCtlInfo(fd, ctlInfo) if err != nil { return nil, fmt.Errorf("CTLIOCGINFO: %v", err) } - sc := sockaddrCtl{ - scLen: uint8(sockaddrCtlSize), - scFamily: unix.AF_SYSTEM, - ssSysaddr: _AF_SYS_CONTROL, - scID: ctlInfo.ctlID, - scUnit: uint32(ifIndex) + 1, + err = unix.Connect(fd, &unix.SockaddrCtl{ + ID: ctlInfo.Id, + Unit: uint32(ifIndex) + 1, + }) + if err != nil { + return nil, fmt.Errorf("SYS_CONNECT: %v", err) } - _, _, errno := unix.RawSyscall( - unix.SYS_CONNECT, - uintptr(fd), - uintptr(unsafe.Pointer(&sc)), - sockaddrCtlSize, - ) - if errno != 0 { - return nil, fmt.Errorf("SYS_CONNECT: %v", errno) + name, err = unix.GetsockoptString(fd, unix.AF_SYS_CONTROL, _UTUN_OPT_IFNAME) + if err != nil { + return nil, fmt.Errorf("failed to retrieve tun name: %w", err) } - var ifName struct { - name [16]byte - } - ifNameSize := uintptr(len(ifName.name)) - _, _, errno = syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(fd), - 2, // SYSPROTO_CONTROL - 2, // UTUN_OPT_IFNAME - uintptr(unsafe.Pointer(&ifName)), - uintptr(unsafe.Pointer(&ifNameSize)), 0) - if errno != 0 { - return nil, fmt.Errorf("SYS_GETSOCKOPT: %v", errno) - } - name = string(ifName.name[:ifNameSize-1]) - - err = syscall.SetNonblock(fd, true) + err = unix.SetNonblock(fd, true) if err != nil { return nil, fmt.Errorf("SetNonblock: %v", err) } - file := os.NewFile(uintptr(fd), "") - t := &tun{ - ReadWriteCloser: file, + ReadWriteCloser: os.NewFile(uintptr(fd), ""), Device: name, - cidr: cidr, + vpnNetworks: vpnNetworks, DefaultMTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -172,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) { return } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in Darwin") } @@ -186,16 +167,6 @@ func (t *tun) Close() error { func (t *tun) Activate() error { devName := t.deviceBytes() - var addr, mask [4]byte - - if !t.cidr.Addr().Is4() { - //TODO: IPV6-WORK - panic("need ipv6") - } - - addr = t.cidr.Addr().As4() - copy(mask[:], prefixToMask(t.cidr)) - s, err := unix.Socket( unix.AF_INET, unix.SOCK_DGRAM, @@ -208,66 +179,18 @@ func (t *tun) Activate() error { fd := uintptr(s) - ifra := ifreqAddr{ - Name: devName, - Addr: unix.RawSockaddrInet4{ - Family: unix.AF_INET, - Addr: addr, - }, - } - - // Set the device ip address - if err = ioctl(fd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil { - return fmt.Errorf("failed to set tun address: %s", err) - } - - // Set the device network - ifra.Addr.Addr = mask - if err = ioctl(fd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil { - return fmt.Errorf("failed to set tun netmask: %s", err) - } - - // Set the device name - ifrf := ifReq{Name: devName} - if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { - return fmt.Errorf("failed to set tun device name: %s", err) - } - // Set the MTU on the device ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)} if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil { return fmt.Errorf("failed to set tun mtu: %v", err) } - /* - // Set the transmit queue length - ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)} - if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil { - // If we can't set the queue length nebula will still work but it may lead to packet loss - l.WithError(err).Error("Failed to set tun tx queue length") - } - */ - - // Bring up the interface - ifrf.Flags = ifrf.Flags | unix.IFF_UP - if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { - return fmt.Errorf("failed to bring the tun device up: %s", err) + // Get the device flags + ifrf := ifReq{Name: devName} + if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { + return fmt.Errorf("failed to get tun flags: %s", err) } - routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) - if err != nil { - return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) - } - defer func() { - unix.Shutdown(routeSock, unix.SHUT_RDWR) - err := unix.Close(routeSock) - if err != nil { - t.l.WithError(err).Error("failed to close AF_ROUTE socket") - } - }() - - routeAddr := &netroute.Inet4Addr{} - maskAddr := &netroute.Inet4Addr{} linkAddr, err := getLinkAddr(t.Device) if err != nil { return err @@ -277,14 +200,18 @@ func (t *tun) Activate() error { } t.linkAddr = linkAddr - copy(routeAddr.IP[:], addr[:]) - copy(maskAddr.IP[:], mask[:]) - err = addRoute(routeSock, routeAddr, maskAddr, linkAddr) - if err != nil { - if errors.Is(err, unix.EEXIST) { - err = fmt.Errorf("unable to add tun route, identical route already exists: %s", t.cidr) + for _, network := range t.vpnNetworks { + if network.Addr().Is4() { + err = t.activate4(network) + if err != nil { + return err + } + } else { + err = t.activate6(network) + if err != nil { + return err + } } - return err } // Run the interface @@ -297,8 +224,89 @@ func (t *tun) Activate() error { return t.addRoutes(false) } +func (t *tun) activate4(network netip.Prefix) error { + s, err := unix.Socket( + unix.AF_INET, + unix.SOCK_DGRAM, + unix.IPPROTO_IP, + ) + if err != nil { + return err + } + defer unix.Close(s) + + ifr := ifreqAlias4{ + Name: t.deviceBytes(), + Addr: unix.RawSockaddrInet4{ + Len: unix.SizeofSockaddrInet4, + Family: unix.AF_INET, + Addr: network.Addr().As4(), + }, + DstAddr: unix.RawSockaddrInet4{ + Len: unix.SizeofSockaddrInet4, + Family: unix.AF_INET, + Addr: network.Addr().As4(), + }, + MaskAddr: unix.RawSockaddrInet4{ + Len: unix.SizeofSockaddrInet4, + Family: unix.AF_INET, + Addr: prefixToMask(network).As4(), + }, + } + + if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil { + return fmt.Errorf("failed to set tun v4 address: %s", err) + } + + err = addRoute(network, t.linkAddr) + if err != nil { + return err + } + + return nil +} + +func (t *tun) activate6(network netip.Prefix) error { + s, err := unix.Socket( + unix.AF_INET6, + unix.SOCK_DGRAM, + unix.IPPROTO_IP, + ) + if err != nil { + return err + } + defer unix.Close(s) + + ifr := ifreqAlias6{ + Name: t.deviceBytes(), + Addr: unix.RawSockaddrInet6{ + Len: unix.SizeofSockaddrInet6, + Family: unix.AF_INET6, + Addr: network.Addr().As16(), + }, + PrefixMask: unix.RawSockaddrInet6{ + Len: unix.SizeofSockaddrInet6, + Family: unix.AF_INET6, + Addr: prefixToMask(network).As16(), + }, + Lifetime: addrLifetime{ + // never expires + Vltime: 0xffffffff, + Pltime: 0xffffffff, + }, + //TODO: CERT-V2 should we disable DAD (duplicate address detection) and mark this as a secured address? + Flags: _IN6_IFF_NODAD, + } + + if err := ioctl(uintptr(s), _SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil { + return fmt.Errorf("failed to set tun address: %s", err) + } + + return nil +} + func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -343,7 +351,7 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr { } // Get the LinkAddr for the interface of the given name -// TODO: Is there an easier way to fetch this when we create the interface? +// Is there an easier way to fetch this when we create the interface? // Maybe SIOCGIFINDEX? but this doesn't appear to exist in the darwin headers. func getLinkAddr(name string) (*netroute.LinkAddr, error) { rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0) @@ -371,38 +379,15 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) { } func (t *tun) addRoutes(logErrors bool) error { - routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) - if err != nil { - return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) - } - - defer func() { - unix.Shutdown(routeSock, unix.SHUT_RDWR) - err := unix.Close(routeSock) - if err != nil { - t.l.WithError(err).Error("failed to close AF_ROUTE socket") - } - }() - - routeAddr := &netroute.Inet4Addr{} - maskAddr := &netroute.Inet4Addr{} routes := *t.Routes.Load() + for _, r := range routes { if !r.Via.IsValid() || !r.Install { // We don't allow route MTUs so only install routes with a via continue } - if !r.Cidr.Addr().Is4() { - //TODO: implement ipv6 - panic("Cant handle ipv6 routes yet") - } - - routeAddr.IP = r.Cidr.Addr().As4() - //TODO: we could avoid the copy - copy(maskAddr.IP[:], prefixToMask(r.Cidr)) - - err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr) + err := addRoute(r.Cidr, t.linkAddr) if err != nil { if errors.Is(err, unix.EEXIST) { t.l.WithField("route", r.Cidr). @@ -424,36 +409,12 @@ func (t *tun) addRoutes(logErrors bool) error { } func (t *tun) removeRoutes(routes []Route) error { - routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) - if err != nil { - return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) - } - - defer func() { - unix.Shutdown(routeSock, unix.SHUT_RDWR) - err := unix.Close(routeSock) - if err != nil { - t.l.WithError(err).Error("failed to close AF_ROUTE socket") - } - }() - - routeAddr := &netroute.Inet4Addr{} - maskAddr := &netroute.Inet4Addr{} - for _, r := range routes { if !r.Install { continue } - if r.Cidr.Addr().Is6() { - //TODO: implement ipv6 - panic("Cant handle ipv6 routes yet") - } - - routeAddr.IP = r.Cidr.Addr().As4() - copy(maskAddr.IP[:], prefixToMask(r.Cidr)) - - err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr) + err := delRoute(r.Cidr, t.linkAddr) if err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") } else { @@ -463,23 +424,39 @@ func (t *tun) removeRoutes(routes []Route) error { return nil } -func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error { - r := netroute.RouteMessage{ +func addRoute(prefix netip.Prefix, gateway netroute.Addr) error { + sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + if err != nil { + return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) + } + defer unix.Close(sock) + + route := &netroute.RouteMessage{ Version: unix.RTM_VERSION, Type: unix.RTM_ADD, Flags: unix.RTF_UP, Seq: 1, - Addrs: []netroute.Addr{ - unix.RTAX_DST: addr, - unix.RTAX_GATEWAY: link, - unix.RTAX_NETMASK: mask, - }, } - data, err := r.Marshal() + if prefix.Addr().Is4() { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, + unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, + unix.RTAX_GATEWAY: gateway, + } + } else { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, + unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, + unix.RTAX_GATEWAY: gateway, + } + } + + data, err := route.Marshal() if err != nil { return fmt.Errorf("failed to create route.RouteMessage: %w", err) } + _, err = unix.Write(sock, data[:]) if err != nil { return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err) @@ -488,19 +465,34 @@ func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) return nil } -func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error { - r := netroute.RouteMessage{ +func delRoute(prefix netip.Prefix, gateway netroute.Addr) error { + sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + if err != nil { + return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) + } + defer unix.Close(sock) + + route := netroute.RouteMessage{ Version: unix.RTM_VERSION, Type: unix.RTM_DELETE, Seq: 1, - Addrs: []netroute.Addr{ - unix.RTAX_DST: addr, - unix.RTAX_GATEWAY: link, - unix.RTAX_NETMASK: mask, - }, } - data, err := r.Marshal() + if prefix.Addr().Is4() { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, + unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, + unix.RTAX_GATEWAY: gateway, + } + } else { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, + unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, + unix.RTAX_GATEWAY: gateway, + } + } + + data, err := route.Marshal() if err != nil { return fmt.Errorf("failed to create route.RouteMessage: %w", err) } @@ -513,7 +505,6 @@ func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) } func (t *tun) Read(to []byte) (int, error) { - buf := make([]byte, len(to)+4) n, err := t.ReadWriteCloser.Read(buf) @@ -551,8 +542,8 @@ func (t *tun) Write(from []byte) (int, error) { return n - 4, err } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { @@ -563,10 +554,12 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") } -func prefixToMask(prefix netip.Prefix) []byte { +func prefixToMask(prefix netip.Prefix) netip.Addr { pLen := 128 if prefix.Addr().Is4() { pLen = 32 } - return net.CIDRMask(prefix.Bits(), pLen) + + addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen)) + return addr } diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index 130f8f9..cfbf17d 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -12,8 +12,8 @@ import ( ) type disabledTun struct { - read chan []byte - cidr netip.Prefix + read chan []byte + vpnNetworks []netip.Prefix // Track these metrics since we don't have the tun device to do it for us tx metrics.Counter @@ -21,11 +21,11 @@ type disabledTun struct { l *logrus.Logger } -func newDisabledTun(cidr netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { +func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { tun := &disabledTun{ - cidr: cidr, - read: make(chan []byte, queueLen), - l: l, + vpnNetworks: vpnNetworks, + read: make(chan []byte, queueLen), + l: l, } if metricsEnabled { @@ -47,8 +47,8 @@ func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr { return netip.Addr{} } -func (t *disabledTun) Cidr() netip.Prefix { - return t.cidr +func (t *disabledTun) Networks() []netip.Prefix { + return t.vpnNetworks } func (*disabledTun) Name() string { diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index bdfeb58..69690e9 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -46,12 +46,12 @@ type ifreqDestroy struct { } type tun struct { - Device string - cidr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger io.ReadWriteCloser } @@ -78,11 +78,11 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open existing tun device var file *os.File var err error @@ -150,7 +150,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err t := &tun{ ReadWriteCloser: file, Device: deviceName, - cidr: cidr, + vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -170,16 +170,16 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err return t, nil } -func (t *tun) Activate() error { +func (t *tun) addIp(cidr netip.Prefix) error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device) + cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) @@ -195,8 +195,18 @@ func (t *tun) Activate() error { return t.addRoutes(false) } +func (t *tun) Activate() error { + for i := range t.vpnNetworks { + err := t.addIp(t.vpnNetworks[i]) + if err != nil { + return err + } + } + return nil +} + func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -237,8 +247,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr { return r } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 20981f0..e99d447 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -21,20 +21,20 @@ import ( type tun struct { io.ReadWriteCloser - cidr netip.Prefix - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + vpnNetworks []netip.Prefix + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger } -func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in iOS") } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/tun") t := &tun{ - cidr: cidr, + vpnNetworks: vpnNetworks, ReadWriteCloser: &tunReadCloser{f: file}, l: l, } @@ -59,7 +59,7 @@ func (t *tun) Activate() error { } func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -142,8 +142,8 @@ func (tr *tunReadCloser) Close() error { return tr.f.Close() } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 0e7e20d..993bd4a 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -11,6 +11,7 @@ import ( "os" "strings" "sync/atomic" + "time" "unsafe" "github.com/gaissmai/bart" @@ -25,7 +26,7 @@ type tun struct { io.ReadWriteCloser fd int Device string - cidr netip.Prefix + vpnNetworks []netip.Prefix MaxMTU int DefaultMTU int TXQueueLen int @@ -40,18 +41,16 @@ type tun struct { l *logrus.Logger } +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks +} + type ifReq struct { Name [16]byte Flags uint16 pad [8]byte } -type ifreqAddr struct { - Name [16]byte - Addr unix.RawSockaddrInet4 - pad [8]byte -} - type ifreqMTU struct { Name [16]byte MTU int32 @@ -64,10 +63,10 @@ type ifreqQLEN struct { pad [8]byte } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, cidr) + t, err := newTunGeneric(c, l, file, vpnNetworks) if err != nil { return nil, err } @@ -77,7 +76,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix return t, nil } -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { // If /dev/net/tun doesn't exist, try to create it (will happen in docker) @@ -112,7 +111,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) ( name := strings.Trim(string(req.Name[:]), "\x00") file := os.NewFile(uintptr(fd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, cidr) + t, err := newTunGeneric(c, l, file, vpnNetworks) if err != nil { return nil, err } @@ -122,11 +121,11 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) ( return t, nil } -func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) { +func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) { t := &tun{ ReadWriteCloser: file, fd: int(file.Fd()), - cidr: cidr, + vpnNetworks: vpnNetworks, TXQueueLen: c.GetInt("tun.tx_queue", 500), useSystemRoutes: c.GetBool("tun.use_system_route_table", false), l: l, @@ -148,7 +147,7 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Pref } func (t *tun) reload(c *config.C, initial bool) error { - routeChange, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -190,11 +189,13 @@ func (t *tun) reload(c *config.C, initial bool) error { } if oldDefaultMTU != newDefaultMTU { - err := t.setDefaultRoute() - if err != nil { - t.l.Warn(err) - } else { - t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU) + for i := range t.vpnNetworks { + err := t.setDefaultRoute(t.vpnNetworks[i]) + if err != nil { + t.l.Warn(err) + } else { + t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU) + } } } @@ -237,10 +238,10 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr { func (t *tun) Write(b []byte) (int, error) { var nn int - max := len(b) + maximum := len(b) for { - n, err := unix.Write(t.fd, b[nn:max]) + n, err := unix.Write(t.fd, b[nn:maximum]) if n > 0 { nn += n } @@ -265,6 +266,58 @@ func (t *tun) deviceBytes() (o [16]byte) { return } +func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool { + for i := range al { + if al[i].Equal(x) { + return true + } + } + return false +} + +// addIPs uses netlink to add all addresses that don't exist, then it removes ones that should not be there +func (t *tun) addIPs(link netlink.Link) error { + newAddrs := make([]*netlink.Addr, len(t.vpnNetworks)) + for i := range t.vpnNetworks { + newAddrs[i] = &netlink.Addr{ + IPNet: &net.IPNet{ + IP: t.vpnNetworks[i].Addr().AsSlice(), + Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()), + }, + Label: t.vpnNetworks[i].Addr().Zone(), + } + } + + //add all new addresses + for i := range newAddrs { + //TODO: CERT-V2 do we want to stack errors and try as many ops as possible? + //AddrReplace still adds new IPs, but if their properties change it will change them as well + if err := netlink.AddrReplace(link, newAddrs[i]); err != nil { + return err + } + } + + //iterate over remainder, remove whoever shouldn't be there + al, err := netlink.AddrList(link, netlink.FAMILY_ALL) + if err != nil { + return fmt.Errorf("failed to get tun address list: %s", err) + } + + for i := range al { + if hasNetlinkAddr(newAddrs, al[i]) { + continue + } + err = netlink.AddrDel(link, &al[i]) + if err != nil { + t.l.WithError(err).Error("failed to remove address from tun address list") + } else { + t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)") + } + } + + return nil +} + func (t *tun) Activate() error { devName := t.deviceBytes() @@ -272,15 +325,8 @@ func (t *tun) Activate() error { t.watchRoutes() } - var addr, mask [4]byte - - //TODO: IPV6-WORK - addr = t.cidr.Addr().As4() - tmask := net.CIDRMask(t.cidr.Bits(), 32) - copy(mask[:], tmask) - s, err := unix.Socket( - unix.AF_INET, + unix.AF_INET, //because everything we use t.ioctlFd for is address family independent, this is fine unix.SOCK_DGRAM, unix.IPPROTO_IP, ) @@ -289,31 +335,19 @@ func (t *tun) Activate() error { } t.ioctlFd = uintptr(s) - ifra := ifreqAddr{ - Name: devName, - Addr: unix.RawSockaddrInet4{ - Family: unix.AF_INET, - Addr: addr, - }, - } - - // Set the device ip address - if err = ioctl(t.ioctlFd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil { - return fmt.Errorf("failed to set tun address: %s", err) - } - - // Set the device network - ifra.Addr.Addr = mask - if err = ioctl(t.ioctlFd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil { - return fmt.Errorf("failed to set tun netmask: %s", err) - } - // Set the device name ifrf := ifReq{Name: devName} if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { return fmt.Errorf("failed to set tun device name: %s", err) } + link, err := netlink.LinkByName(t.Device) + if err != nil { + return fmt.Errorf("failed to get tun device link: %s", err) + } + + t.deviceIndex = link.Attrs().Index + // Setup our default MTU t.setMTU() @@ -324,20 +358,21 @@ func (t *tun) Activate() error { t.l.WithError(err).Error("Failed to set tun tx queue length") } + if err = t.addIPs(link); err != nil { + return err + } + // Bring up the interface ifrf.Flags = ifrf.Flags | unix.IFF_UP if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { return fmt.Errorf("failed to bring the tun device up: %s", err) } - link, err := netlink.LinkByName(t.Device) - if err != nil { - return fmt.Errorf("failed to get tun device link: %s", err) - } - t.deviceIndex = link.Attrs().Index - - if err = t.setDefaultRoute(); err != nil { - return err + //set route MTU + for i := range t.vpnNetworks { + if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil { + return fmt.Errorf("failed to set default route MTU: %w", err) + } } // Set the routes @@ -363,12 +398,10 @@ func (t *tun) setMTU() { } } -func (t *tun) setDefaultRoute() error { - // Default route - +func (t *tun) setDefaultRoute(cidr netip.Prefix) error { dr := &net.IPNet{ - IP: t.cidr.Masked().Addr().AsSlice(), - Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()), + IP: cidr.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()), } nr := netlink.Route{ @@ -377,14 +410,27 @@ func (t *tun) setDefaultRoute() error { MTU: t.DefaultMTU, AdvMSS: t.advMSS(Route{}), Scope: unix.RT_SCOPE_LINK, - Src: net.IP(t.cidr.Addr().AsSlice()), + Src: net.IP(cidr.Addr().AsSlice()), Protocol: unix.RTPROT_KERNEL, Table: unix.RT_TABLE_MAIN, Type: unix.RTN_UNICAST, } err := netlink.RouteReplace(&nr) if err != nil { - return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err) + t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying") + //retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument` + for i := 0; i < 2; i++ { + time.Sleep(100 * time.Millisecond) + err = netlink.RouteReplace(&nr) + if err == nil { + break + } else { + t.l.WithError(err).WithField("cidr", cidr).WithField("mtu", t.DefaultMTU).Warn("Failed to set default route MTU, retrying") + } + } + if err != nil { + return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err) + } } return nil @@ -463,10 +509,6 @@ func (t *tun) removeRoutes(routes []Route) { } } -func (t *tun) Cidr() netip.Prefix { - return t.cidr -} - func (t *tun) Name() string { return t.Device } @@ -515,7 +557,6 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { return } - //TODO: IPV6-WORK what if not ok? gwAddr, ok := netip.AddrFromSlice(r.Gw) if !ok { t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address") @@ -523,15 +564,16 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { } gwAddr = gwAddr.Unmap() - if !t.cidr.Contains(gwAddr) { - // Gateway isn't in our overlay network, ignore - t.l.WithField("route", r).Debug("Ignoring route update, not in our network") - return + withinNetworks := false + for i := range t.vpnNetworks { + if t.vpnNetworks[i].Contains(gwAddr) { + withinNetworks = true + break + } } - - if x := r.Dst.IP.To4(); x == nil { - // Nebula only handles ipv4 on the overlay currently - t.l.WithField("route", r).Debug("Ignoring route update, destination is not ipv4") + if !withinNetworks { + // Gateway isn't in our overlay network, ignore + t.l.WithField("route", r).Debug("Ignoring route update, not in our networks") return } @@ -563,11 +605,11 @@ func (t *tun) Close() error { } if t.ReadWriteCloser != nil { - t.ReadWriteCloser.Close() + _ = t.ReadWriteCloser.Close() } if t.ioctlFd > 0 { - os.NewFile(t.ioctlFd, "ioctlFd").Close() + _ = os.NewFile(t.ioctlFd, "ioctlFd").Close() } return nil diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 24ab24f..f7586cb 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -27,12 +27,12 @@ type ifreqDestroy struct { } type tun struct { - Device string - cidr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger io.ReadWriteCloser } @@ -58,13 +58,13 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var file *os.File var err error @@ -84,7 +84,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err t := &tun{ ReadWriteCloser: file, Device: deviceName, - cidr: cidr, + vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -104,17 +104,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err return t, nil } -func (t *tun) Activate() error { +func (t *tun) addIp(cidr netip.Prefix) error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.Addr().String()) + cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) @@ -130,8 +130,18 @@ func (t *tun) Activate() error { return t.addRoutes(false) } +func (t *tun) Activate() error { + for i := range t.vpnNetworks { + err := t.addIp(t.vpnNetworks[i]) + if err != nil { + return err + } + } + return nil +} + func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -172,8 +182,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr { return r } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { @@ -192,7 +202,7 @@ func (t *tun) addRoutes(logErrors bool) error { continue } - cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.Addr().String()) + cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) @@ -213,7 +223,8 @@ func (t *tun) removeRoutes(routes []Route) error { continue } - cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.Addr().String()) + //TODO: CERT-V2 is this right? + cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 6463ccb..a2fd184 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -21,12 +21,12 @@ import ( ) type tun struct { - Device string - cidr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger io.ReadWriteCloser @@ -42,13 +42,13 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD") } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { deviceName := c.GetString("tun.dev", "") if deviceName == "" { return nil, fmt.Errorf("a device name in the format of tunN must be specified") @@ -66,7 +66,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err t := &tun{ ReadWriteCloser: file, Device: deviceName, - cidr: cidr, + vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -87,7 +87,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err } func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -123,10 +123,10 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) Activate() error { +func (t *tun) addIp(cidr netip.Prefix) error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) @@ -138,7 +138,7 @@ func (t *tun) Activate() error { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.Addr().String()) + cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) @@ -148,6 +148,16 @@ func (t *tun) Activate() error { return t.addRoutes(false) } +func (t *tun) Activate() error { + for i := range t.vpnNetworks { + err := t.addIp(t.vpnNetworks[i]) + if err != nil { + return err + } + } + return nil +} + func (t *tun) RouteFor(ip netip.Addr) netip.Addr { r, _ := t.routeTree.Load().Lookup(ip) return r @@ -160,8 +170,8 @@ func (t *tun) addRoutes(logErrors bool) error { // We don't allow route MTUs so only install routes with a via continue } - - cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.Addr().String()) + //TODO: CERT-V2 is this right? + cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) @@ -181,8 +191,8 @@ func (t *tun) removeRoutes(routes []Route) error { if !r.Install { continue } - - cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.Addr().String()) + //TODO: CERT-V2 is this right? + cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") @@ -193,8 +203,8 @@ func (t *tun) removeRoutes(routes []Route) error { return nil } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index ba15723..cc3942f 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -16,19 +16,19 @@ import ( ) type TestTun struct { - Device string - cidr netip.Prefix - Routes []Route - routeTree *bart.Table[netip.Addr] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + Routes []Route + routeTree *bart.Table[netip.Addr] + l *logrus.Logger closed atomic.Bool rxPackets chan []byte // Packets to receive into nebula TxPackets chan []byte // Packets transmitted outside by nebula } -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, error) { - _, routes, err := getAllRoutesFromConfig(c, cidr, true) +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { + _, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true) if err != nil { return nil, err } @@ -38,17 +38,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, } return &TestTun{ - Device: c.GetString("tun.dev", ""), - cidr: cidr, - Routes: routes, - routeTree: routeTree, - l: l, - rxPackets: make(chan []byte, 10), - TxPackets: make(chan []byte, 10), + Device: c.GetString("tun.dev", ""), + vpnNetworks: vpnNetworks, + Routes: routes, + routeTree: routeTree, + l: l, + rxPackets: make(chan []byte, 10), + TxPackets: make(chan []byte, 10), }, nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*TestTun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) { return nil, fmt.Errorf("newTunFromFd not supported") } @@ -95,8 +95,8 @@ func (t *TestTun) Activate() error { return nil } -func (t *TestTun) Cidr() netip.Prefix { - return t.cidr +func (t *TestTun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *TestTun) Name() string { diff --git a/overlay/tun_water_windows.go b/overlay/tun_water_windows.go deleted file mode 100644 index d78f564..0000000 --- a/overlay/tun_water_windows.go +++ /dev/null @@ -1,208 +0,0 @@ -package overlay - -import ( - "fmt" - "io" - "net" - "net/netip" - "os/exec" - "strconv" - "sync/atomic" - - "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/util" - "github.com/songgao/water" -) - -type waterTun struct { - Device string - cidr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger - f *net.Interface - *water.Interface -} - -func newWaterTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*waterTun, error) { - // NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate() - t := &waterTun{ - cidr: cidr, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, - } - - err := t.reload(c, true) - if err != nil { - return nil, err - } - - c.RegisterReloadCallback(func(c *config.C) { - err := t.reload(c, false) - if err != nil { - util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) - } - }) - - return t, nil -} - -func (t *waterTun) Activate() error { - var err error - t.Interface, err = water.New(water.Config{ - DeviceType: water.TUN, - PlatformSpecificParams: water.PlatformSpecificParams{ - ComponentID: "tap0901", - Network: t.cidr.String(), - }, - }) - if err != nil { - return fmt.Errorf("activate failed: %v", err) - } - - t.Device = t.Interface.Name() - - // TODO use syscalls instead of exec.Command - err = exec.Command( - `C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address", - fmt.Sprintf("name=%s", t.Device), - "source=static", - fmt.Sprintf("addr=%s", t.cidr.Addr()), - fmt.Sprintf("mask=%s", net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen())), - "gateway=none", - ).Run() - if err != nil { - return fmt.Errorf("failed to run 'netsh' to set address: %s", err) - } - err = exec.Command( - `C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "interface", - t.Device, - fmt.Sprintf("mtu=%d", t.MTU), - ).Run() - if err != nil { - return fmt.Errorf("failed to run 'netsh' to set MTU: %s", err) - } - - t.f, err = net.InterfaceByName(t.Device) - if err != nil { - return fmt.Errorf("failed to find interface named %s: %v", t.Device, err) - } - - err = t.addRoutes(false) - if err != nil { - return err - } - - return nil -} - -func (t *waterTun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) - if err != nil { - return err - } - - if !initial && !change { - return nil - } - - routeTree, err := makeRouteTree(t.l, routes, false) - if err != nil { - return err - } - - // Teach nebula how to handle the routes before establishing them in the system table - oldRoutes := t.Routes.Swap(&routes) - t.routeTree.Store(routeTree) - - if !initial { - // Remove first, if the system removes a wanted route hopefully it will be re-added next - t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) - - // Ensure any routes we actually want are installed - err = t.addRoutes(true) - if err != nil { - // Catch any stray logs - util.LogWithContextIfNeeded("Failed to set routes", err, t.l) - } else { - for _, r := range findRemovedRoutes(routes, *oldRoutes) { - t.l.WithField("route", r).Info("Removed route") - } - } - } - - return nil -} - -func (t *waterTun) addRoutes(logErrors bool) error { - // Path routes - routes := *t.Routes.Load() - for _, r := range routes { - if !r.Via.IsValid() || !r.Install { - // We don't allow route MTUs so only install routes with a via - continue - } - - err := exec.Command( - "C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric), - ).Run() - - if err != nil { - retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) - if logErrors { - retErr.Log(t.l) - } else { - return retErr - } - } else { - t.l.WithField("route", r).Info("Added route") - } - } - - return nil -} - -func (t *waterTun) removeRoutes(routes []Route) { - for _, r := range routes { - if !r.Install { - continue - } - - err := exec.Command( - "C:\\Windows\\System32\\route.exe", "delete", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric), - ).Run() - if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") - } else { - t.l.WithField("route", r).Info("Removed route") - } - } -} - -func (t *waterTun) RouteFor(ip netip.Addr) netip.Addr { - r, _ := t.routeTree.Load().Lookup(ip) - return r -} - -func (t *waterTun) Cidr() netip.Prefix { - return t.cidr -} - -func (t *waterTun) Name() string { - return t.Device -} - -func (t *waterTun) Close() error { - if t.Interface == nil { - return nil - } - - return t.Interface.Close() -} - -func (t *waterTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") -} diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 3d88309..289999d 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -4,41 +4,268 @@ package overlay import ( + "crypto" "fmt" + "io" "net/netip" "os" "path/filepath" "runtime" + "sync/atomic" "syscall" + "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/util" + "github.com/slackhq/nebula/wintun" + "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (Device, error) { +const tunGUIDLabel = "Fixed Nebula Windows GUID v1" + +type winTun struct { + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger + + tun *wintun.NativeTun +} + +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) { return nil, fmt.Errorf("newTunFromFd not supported in Windows") } -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (Device, error) { - useWintun := true - if err := checkWinTunExists(); err != nil { - l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver") - useWintun = false - } - - if useWintun { - device, err := newWinTun(c, l, cidr, multiqueue) - if err != nil { - return nil, fmt.Errorf("create Wintun interface failed, %w", err) - } - return device, nil - } - - device, err := newWaterTun(c, l, cidr, multiqueue) +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) { + err := checkWinTunExists() if err != nil { - return nil, fmt.Errorf("create wintap driver failed, %w", err) + return nil, fmt.Errorf("can not load the wintun driver: %w", err) } - return device, nil + + deviceName := c.GetString("tun.dev", "") + guid, err := generateGUIDByDeviceName(deviceName) + if err != nil { + return nil, fmt.Errorf("generate GUID failed: %w", err) + } + + t := &winTun{ + Device: deviceName, + vpnNetworks: vpnNetworks, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, + } + + err = t.reload(c, true) + if err != nil { + return nil, err + } + + var tunDevice wintun.Device + tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) + if err != nil { + // Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device. + // Trying a second time resolves the issue. + l.WithError(err).Debug("Failed to create wintun device, retrying") + tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) + if err != nil { + return nil, fmt.Errorf("create TUN device failed: %w", err) + } + } + t.tun = tunDevice.(*wintun.NativeTun) + + c.RegisterReloadCallback(func(c *config.C) { + err := t.reload(c, false) + if err != nil { + util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) + } + }) + + return t, nil +} + +func (t *winTun) reload(c *config.C, initial bool) error { + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) + if err != nil { + return err + } + + if !initial && !change { + return nil + } + + routeTree, err := makeRouteTree(t.l, routes, false) + if err != nil { + return err + } + + // Teach nebula how to handle the routes before establishing them in the system table + oldRoutes := t.Routes.Swap(&routes) + t.routeTree.Store(routeTree) + + if !initial { + // Remove first, if the system removes a wanted route hopefully it will be re-added next + err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) + if err != nil { + util.LogWithContextIfNeeded("Failed to remove routes", err, t.l) + } + + // Ensure any routes we actually want are installed + err = t.addRoutes(true) + if err != nil { + // Catch any stray logs + util.LogWithContextIfNeeded("Failed to add routes", err, t.l) + } + } + + return nil +} + +func (t *winTun) Activate() error { + luid := winipcfg.LUID(t.tun.LUID()) + + err := luid.SetIPAddresses(t.vpnNetworks) + if err != nil { + return fmt.Errorf("failed to set address: %w", err) + } + + err = t.addRoutes(false) + if err != nil { + return err + } + + return nil +} + +func (t *winTun) addRoutes(logErrors bool) error { + luid := winipcfg.LUID(t.tun.LUID()) + routes := *t.Routes.Load() + foundDefault4 := false + + for _, r := range routes { + if !r.Via.IsValid() || !r.Install { + // We don't allow route MTUs so only install routes with a via + continue + } + + // Add our unsafe route + err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric)) + if err != nil { + retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) + if logErrors { + retErr.Log(t.l) + continue + } else { + return retErr + } + } else { + t.l.WithField("route", r).Info("Added route") + } + + if !foundDefault4 { + if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 { + foundDefault4 = true + } + } + } + + ipif, err := luid.IPInterface(windows.AF_INET) + if err != nil { + return fmt.Errorf("failed to get ip interface: %w", err) + } + + ipif.NLMTU = uint32(t.MTU) + if foundDefault4 { + ipif.UseAutomaticMetric = false + ipif.Metric = 0 + } + + if err := ipif.Set(); err != nil { + return fmt.Errorf("failed to set ip interface: %w", err) + } + return nil +} + +func (t *winTun) removeRoutes(routes []Route) error { + luid := winipcfg.LUID(t.tun.LUID()) + + for _, r := range routes { + if !r.Install { + continue + } + + err := luid.DeleteRoute(r.Cidr, r.Via) + if err != nil { + t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + } else { + t.l.WithField("route", r).Info("Removed route") + } + } + return nil +} + +func (t *winTun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) + return r +} + +func (t *winTun) Networks() []netip.Prefix { + return t.vpnNetworks +} + +func (t *winTun) Name() string { + return t.Device +} + +func (t *winTun) Read(b []byte) (int, error) { + return t.tun.Read(b, 0) +} + +func (t *winTun) Write(b []byte) (int, error) { + return t.tun.Write(b, 0) +} + +func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { + return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") +} + +func (t *winTun) Close() error { + // It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes, + // so to be certain, just remove everything before destroying. + luid := winipcfg.LUID(t.tun.LUID()) + _ = luid.FlushRoutes(windows.AF_INET) + _ = luid.FlushIPAddresses(windows.AF_INET) + + _ = luid.FlushRoutes(windows.AF_INET6) + _ = luid.FlushIPAddresses(windows.AF_INET6) + + _ = luid.FlushDNS(windows.AF_INET) + _ = luid.FlushDNS(windows.AF_INET6) + + return t.tun.Close() +} + +func generateGUIDByDeviceName(name string) (*windows.GUID, error) { + // GUID is 128 bit + hash := crypto.MD5.New() + + _, err := hash.Write([]byte(tunGUIDLabel)) + if err != nil { + return nil, err + } + + _, err = hash.Write([]byte(name)) + if err != nil { + return nil, err + } + + sum := hash.Sum(nil) + + return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil } func checkWinTunExists() error { diff --git a/overlay/tun_wintun_windows.go b/overlay/tun_wintun_windows.go deleted file mode 100644 index d010387..0000000 --- a/overlay/tun_wintun_windows.go +++ /dev/null @@ -1,252 +0,0 @@ -package overlay - -import ( - "crypto" - "fmt" - "io" - "net/netip" - "sync/atomic" - "unsafe" - - "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/util" - "github.com/slackhq/nebula/wintun" - "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" -) - -const tunGUIDLabel = "Fixed Nebula Windows GUID v1" - -type winTun struct { - Device string - cidr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger - - tun *wintun.NativeTun -} - -func generateGUIDByDeviceName(name string) (*windows.GUID, error) { - // GUID is 128 bit - hash := crypto.MD5.New() - - _, err := hash.Write([]byte(tunGUIDLabel)) - if err != nil { - return nil, err - } - - _, err = hash.Write([]byte(name)) - if err != nil { - return nil, err - } - - sum := hash.Sum(nil) - - return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil -} - -func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTun, error) { - deviceName := c.GetString("tun.dev", "") - guid, err := generateGUIDByDeviceName(deviceName) - if err != nil { - return nil, fmt.Errorf("generate GUID failed: %w", err) - } - - t := &winTun{ - Device: deviceName, - cidr: cidr, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, - } - - err = t.reload(c, true) - if err != nil { - return nil, err - } - - var tunDevice wintun.Device - tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) - if err != nil { - // Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device. - // Trying a second time resolves the issue. - l.WithError(err).Debug("Failed to create wintun device, retrying") - tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) - if err != nil { - return nil, fmt.Errorf("create TUN device failed: %w", err) - } - } - t.tun = tunDevice.(*wintun.NativeTun) - - c.RegisterReloadCallback(func(c *config.C) { - err := t.reload(c, false) - if err != nil { - util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) - } - }) - - return t, nil -} - -func (t *winTun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) - if err != nil { - return err - } - - if !initial && !change { - return nil - } - - routeTree, err := makeRouteTree(t.l, routes, false) - if err != nil { - return err - } - - // Teach nebula how to handle the routes before establishing them in the system table - oldRoutes := t.Routes.Swap(&routes) - t.routeTree.Store(routeTree) - - if !initial { - // Remove first, if the system removes a wanted route hopefully it will be re-added next - err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) - if err != nil { - util.LogWithContextIfNeeded("Failed to remove routes", err, t.l) - } - - // Ensure any routes we actually want are installed - err = t.addRoutes(true) - if err != nil { - // Catch any stray logs - util.LogWithContextIfNeeded("Failed to add routes", err, t.l) - } - } - - return nil -} - -func (t *winTun) Activate() error { - luid := winipcfg.LUID(t.tun.LUID()) - - err := luid.SetIPAddresses([]netip.Prefix{t.cidr}) - if err != nil { - return fmt.Errorf("failed to set address: %w", err) - } - - err = t.addRoutes(false) - if err != nil { - return err - } - - return nil -} - -func (t *winTun) addRoutes(logErrors bool) error { - luid := winipcfg.LUID(t.tun.LUID()) - routes := *t.Routes.Load() - foundDefault4 := false - - for _, r := range routes { - if !r.Via.IsValid() || !r.Install { - // We don't allow route MTUs so only install routes with a via - continue - } - - // Add our unsafe route - err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric)) - if err != nil { - retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) - if logErrors { - retErr.Log(t.l) - continue - } else { - return retErr - } - } else { - t.l.WithField("route", r).Info("Added route") - } - - if !foundDefault4 { - if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 { - foundDefault4 = true - } - } - } - - ipif, err := luid.IPInterface(windows.AF_INET) - if err != nil { - return fmt.Errorf("failed to get ip interface: %w", err) - } - - ipif.NLMTU = uint32(t.MTU) - if foundDefault4 { - ipif.UseAutomaticMetric = false - ipif.Metric = 0 - } - - if err := ipif.Set(); err != nil { - return fmt.Errorf("failed to set ip interface: %w", err) - } - return nil -} - -func (t *winTun) removeRoutes(routes []Route) error { - luid := winipcfg.LUID(t.tun.LUID()) - - for _, r := range routes { - if !r.Install { - continue - } - - err := luid.DeleteRoute(r.Cidr, r.Via) - if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") - } else { - t.l.WithField("route", r).Info("Removed route") - } - } - return nil -} - -func (t *winTun) RouteFor(ip netip.Addr) netip.Addr { - r, _ := t.routeTree.Load().Lookup(ip) - return r -} - -func (t *winTun) Cidr() netip.Prefix { - return t.cidr -} - -func (t *winTun) Name() string { - return t.Device -} - -func (t *winTun) Read(b []byte) (int, error) { - return t.tun.Read(b, 0) -} - -func (t *winTun) Write(b []byte) (int, error) { - return t.tun.Write(b, 0) -} - -func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") -} - -func (t *winTun) Close() error { - // It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes, - // so to be certain, just remove everything before destroying. - luid := winipcfg.LUID(t.tun.LUID()) - _ = luid.FlushRoutes(windows.AF_INET) - _ = luid.FlushIPAddresses(windows.AF_INET) - /* We don't support IPV6 yet - _ = luid.FlushRoutes(windows.AF_INET6) - _ = luid.FlushIPAddresses(windows.AF_INET6) - */ - _ = luid.FlushDNS(windows.AF_INET) - - return t.tun.Close() -} diff --git a/overlay/user.go b/overlay/user.go index 1bb4ef5..ae665f3 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -8,16 +8,16 @@ import ( "github.com/slackhq/nebula/config" ) -func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { - return NewUserDevice(tunCidr) +func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { + return NewUserDevice(vpnNetworks) } -func NewUserDevice(tunCidr netip.Prefix) (Device, error) { +func NewUserDevice(vpnNetworks []netip.Prefix) (Device, error) { // these pipes guarantee each write/read will match 1:1 or, ow := io.Pipe() ir, iw := io.Pipe() return &UserDevice{ - tunCidr: tunCidr, + vpnNetworks: vpnNetworks, outboundReader: or, outboundWriter: ow, inboundReader: ir, @@ -26,7 +26,7 @@ func NewUserDevice(tunCidr netip.Prefix) (Device, error) { } type UserDevice struct { - tunCidr netip.Prefix + vpnNetworks []netip.Prefix outboundReader *io.PipeReader outboundWriter *io.PipeWriter @@ -38,7 +38,7 @@ type UserDevice struct { func (d *UserDevice) Activate() error { return nil } -func (d *UserDevice) Cidr() netip.Prefix { return d.tunCidr } +func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks } func (d *UserDevice) Name() string { return "faketun0" } func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip } func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { diff --git a/pkclient/pkclient.go b/pkclient/pkclient.go new file mode 100644 index 0000000..7061de6 --- /dev/null +++ b/pkclient/pkclient.go @@ -0,0 +1,87 @@ +package pkclient + +import ( + "crypto/ecdsa" + "crypto/x509" + "fmt" + "io" + "strconv" + + "github.com/stefanberger/go-pkcs11uri" +) + +type Client interface { + io.Closer + GetPubKey() ([]byte, error) + DeriveNoise(peerPubKey []byte) ([]byte, error) + Test() error +} + +const NoiseKeySize = 32 + +func FromUrl(pkurl string) (*PKClient, error) { + uri := pkcs11uri.New() + uri.SetAllowAnyModule(true) //todo + err := uri.Parse(pkurl) + if err != nil { + return nil, err + } + + module, err := uri.GetModule() + if err != nil { + return nil, err + } + + slotid := 0 + slot, ok := uri.GetPathAttribute("slot-id", false) + if !ok { + slotid = 0 + } else { + slotid, err = strconv.Atoi(slot) + if err != nil { + return nil, err + } + } + + pin, _ := uri.GetPIN() + id, _ := uri.GetPathAttribute("id", false) + label, _ := uri.GetPathAttribute("object", false) + + return New(module, uint(slotid), pin, id, label) +} + +func ecKeyToArray(key *ecdsa.PublicKey) []byte { + x := make([]byte, 32) + y := make([]byte, 32) + key.X.FillBytes(x) + key.Y.FillBytes(y) + return append([]byte{0x04}, append(x, y...)...) +} + +func formatPubkeyFromPublicKeyInfoAttr(d []byte) ([]byte, error) { + e, err := x509.ParsePKIXPublicKey(d) + if err != nil { + return nil, err + } + switch t := e.(type) { + case *ecdsa.PublicKey: + return ecKeyToArray(e.(*ecdsa.PublicKey)), nil + default: + return nil, fmt.Errorf("unknown public key type: %T", t) + } +} + +func (c *PKClient) Test() error { + pub, err := c.GetPubKey() + if err != nil { + return fmt.Errorf("failed to get public key: %w", err) + } + out, err := c.DeriveNoise(pub) //do an ECDH with ourselves as a quick test + if err != nil { + return err + } + if len(out) != NoiseKeySize { + return fmt.Errorf("got a key of %d bytes, expected %d", len(out), NoiseKeySize) + } + return nil +} diff --git a/pkclient/pkclient_cgo.go b/pkclient/pkclient_cgo.go new file mode 100644 index 0000000..a2ead55 --- /dev/null +++ b/pkclient/pkclient_cgo.go @@ -0,0 +1,229 @@ +//go:build cgo && pkcs11 + +package pkclient + +import ( + "encoding/asn1" + "errors" + "fmt" + "log" + "math/big" + + "github.com/miekg/pkcs11" + "github.com/miekg/pkcs11/p11" +) + +type PKClient struct { + module p11.Module + session p11.Session + id []byte + label []byte + privKeyObj p11.Object + pubKeyObj p11.Object +} + +type ecdsaSignature struct { + R, S *big.Int +} + +// New tries to open a session with the HSM, select the slot and login to it +func New(hsmPath string, slotId uint, pin string, id string, label string) (*PKClient, error) { + module, err := p11.OpenModule(hsmPath) + if err != nil { + return nil, fmt.Errorf("failed to load module library: %s", hsmPath) + } + + slots, err := module.Slots() + if err != nil { + module.Destroy() + return nil, err + } + + // Try to open a session on the slot + slotIdx := 0 + for i, slot := range slots { + if slot.ID() == slotId { + slotIdx = i + break + } + } + + client := &PKClient{ + module: module, + id: []byte(id), + label: []byte(label), + } + + client.session, err = slots[slotIdx].OpenWriteSession() + if err != nil { + module.Destroy() + return nil, fmt.Errorf("failed to open session on slot %d", slotId) + } + + if len(pin) != 0 { + err = client.session.Login(pin) + if err != nil { + // ignore "already logged in" + if !errors.Is(err, pkcs11.Error(256)) { + _ = client.session.Close() + return nil, fmt.Errorf("unable to login. error: %w", err) + } + } + } + + // Make sure the hsm has a private key for deriving + client.privKeyObj, err = client.findDeriveKey(client.id, client.label, true) + if err != nil { + _ = client.Close() //log out, close session, destroy module + return nil, fmt.Errorf("failed to find private key for deriving: %w", err) + } + + return client, nil +} + +// Close cleans up properly and logs out +func (c *PKClient) Close() error { + var err error = nil + if c.session != nil { + _ = c.session.Logout() //if logout fails, we still want to close + err = c.session.Close() + } + + c.module.Destroy() + return err +} + +// Try to find a suitable key on the hsm for key derivation +// parameter GET_PUB_KEY sets the search pattern for a public or private key +func (c *PKClient) findDeriveKey(id []byte, label []byte, private bool) (key p11.Object, err error) { + keyClass := pkcs11.CKO_PRIVATE_KEY + if !private { + keyClass = pkcs11.CKO_PUBLIC_KEY + } + keyAttrs := []*pkcs11.Attribute{ + //todo, not all HSMs seem to report this, even if its true: pkcs11.NewAttribute(pkcs11.CKA_DERIVE, true), + pkcs11.NewAttribute(pkcs11.CKA_CLASS, keyClass), + } + + if id != nil && len(id) != 0 { + keyAttrs = append(keyAttrs, pkcs11.NewAttribute(pkcs11.CKA_ID, id)) + } + if label != nil && len(label) != 0 { + keyAttrs = append(keyAttrs, pkcs11.NewAttribute(pkcs11.CKA_LABEL, label)) + } + + return c.session.FindObject(keyAttrs) +} + +func (c *PKClient) listDeriveKeys(id []byte, label []byte, private bool) { + keyClass := pkcs11.CKO_PRIVATE_KEY + if !private { + keyClass = pkcs11.CKO_PUBLIC_KEY + } + keyAttrs := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_CLASS, keyClass), + } + + if id != nil && len(id) != 0 { + keyAttrs = append(keyAttrs, pkcs11.NewAttribute(pkcs11.CKA_ID, id)) + } + if label != nil && len(label) != 0 { + keyAttrs = append(keyAttrs, pkcs11.NewAttribute(pkcs11.CKA_LABEL, label)) + } + + objects, err := c.session.FindObjects(keyAttrs) + if err != nil { + return + } + + for _, obj := range objects { + l, err := obj.Label() + log.Printf("%s, %v", l, err) + a, err := obj.Attribute(pkcs11.CKA_DERIVE) + log.Printf("DERIVE: %s %v, %v", l, a, err) + } +} + +// SignASN1 signs some data. Returns the ASN.1 encoded signature. +func (c *PKClient) SignASN1(data []byte) ([]byte, error) { + mech := pkcs11.NewMechanism(pkcs11.CKM_ECDSA_SHA256, nil) + sk := p11.PrivateKey(c.privKeyObj) + rawSig, err := sk.Sign(*mech, data) + if err != nil { + return nil, err + } + + // PKCS #11 Mechanisms v2.30: + // "The signature octets correspond to the concatenation of the ECDSA values r and s, + // both represented as an octet string of equal length of at most nLen with the most + // significant byte first. If r and s have different octet length, the shorter of both + // must be padded with leading zero octets such that both have the same octet length. + // Loosely spoken, the first half of the signature is r and the second half is s." + r := new(big.Int).SetBytes(rawSig[:len(rawSig)/2]) + s := new(big.Int).SetBytes(rawSig[len(rawSig)/2:]) + return asn1.Marshal(ecdsaSignature{r, s}) +} + +// DeriveNoise derives a shared secret using the input public key against the private key that was found during setup. +// Returns a fixed 32 byte array. +func (c *PKClient) DeriveNoise(peerPubKey []byte) ([]byte, error) { + // Before we call derive, we need to have an array of attributes which specify the type of + // key to be returned, in our case, it's the shared secret key, produced via deriving + // This template pulled from OpenSC pkclient-tool.c line 4038 + attrTemplate := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_TOKEN, false), + pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_SECRET_KEY), + pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, pkcs11.CKK_GENERIC_SECRET), + pkcs11.NewAttribute(pkcs11.CKA_SENSITIVE, false), + pkcs11.NewAttribute(pkcs11.CKA_EXTRACTABLE, true), + pkcs11.NewAttribute(pkcs11.CKA_ENCRYPT, true), + pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true), + pkcs11.NewAttribute(pkcs11.CKA_WRAP, true), + pkcs11.NewAttribute(pkcs11.CKA_UNWRAP, true), + } + + // Set up the parameters which include the peer's public key + ecdhParams := pkcs11.NewECDH1DeriveParams(pkcs11.CKD_NULL, nil, peerPubKey) + mech := pkcs11.NewMechanism(pkcs11.CKM_ECDH1_DERIVE, ecdhParams) + sk := p11.PrivateKey(c.privKeyObj) + + tmpKey, err := sk.Derive(*mech, attrTemplate) + if err != nil { + return nil, err + } + if tmpKey == nil || len(tmpKey) == 0 { + return nil, fmt.Errorf("got an empty secret key") + } + secret := make([]byte, NoiseKeySize) + copy(secret[:], tmpKey[:NoiseKeySize]) + return secret, nil +} + +func (c *PKClient) GetPubKey() ([]byte, error) { + d, err := c.privKeyObj.Attribute(pkcs11.CKA_PUBLIC_KEY_INFO) + if err != nil { + return nil, err + } + if d != nil && len(d) > 0 { + return formatPubkeyFromPublicKeyInfoAttr(d) + } + c.pubKeyObj, err = c.findDeriveKey(c.id, c.label, false) + if err != nil { + return nil, fmt.Errorf("pkcs11 module gave us a nil CKA_PUBLIC_KEY_INFO, and looking up the public key also failed: %w", err) + } + d, err = c.pubKeyObj.Attribute(pkcs11.CKA_EC_POINT) + if err != nil { + return nil, fmt.Errorf("pkcs11 module gave us a nil CKA_PUBLIC_KEY_INFO, and reading CKA_EC_POINT also failed: %w", err) + } + if d == nil || len(d) < 1 { + return nil, fmt.Errorf("pkcs11 module gave us a nil or empty CKA_EC_POINT") + } + switch len(d) { + case 65: //length of 0x04 + len(X) + len(Y) + return d, nil + case 67: //as above, DER-encoded IIRC? + return d[2:], nil + default: + return nil, fmt.Errorf("unknown public key length: %d", len(d)) + } +} diff --git a/pkclient/pkclient_stub.go b/pkclient/pkclient_stub.go new file mode 100644 index 0000000..36b0fc9 --- /dev/null +++ b/pkclient/pkclient_stub.go @@ -0,0 +1,30 @@ +//go:build !cgo || !pkcs11 + +package pkclient + +import "errors" + +type PKClient struct { +} + +var notImplemented = errors.New("not implemented") + +func New(hsmPath string, slotId uint, pin string, id string, label string) (*PKClient, error) { + return nil, notImplemented +} + +func (c *PKClient) Close() error { + return nil +} + +func (c *PKClient) SignASN1(data []byte) ([]byte, error) { + return nil, notImplemented +} + +func (c *PKClient) DeriveNoise(_ []byte) ([]byte, error) { + return nil, notImplemented +} + +func (c *PKClient) GetPubKey() ([]byte, error) { + return nil, notImplemented +} diff --git a/pki.go b/pki.go index ab95a04..888da7c 100644 --- a/pki.go +++ b/pki.go @@ -1,13 +1,19 @@ package nebula import ( + "encoding/binary" + "encoding/json" "errors" "fmt" + "net" + "net/netip" "os" + "slices" "strings" "sync/atomic" "time" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" @@ -16,16 +22,27 @@ import ( type PKI struct { cs atomic.Pointer[CertState] - caPool atomic.Pointer[cert.NebulaCAPool] + caPool atomic.Pointer[cert.CAPool] l *logrus.Logger } type CertState struct { - Certificate *cert.NebulaCertificate - RawCertificate []byte - RawCertificateNoKey []byte - PublicKey []byte - PrivateKey []byte + v1Cert cert.Certificate + v1HandshakeBytes []byte + + v2Cert cert.Certificate + v2HandshakeBytes []byte + + defaultVersion cert.Version + privateKey []byte + pkcs11Backed bool + cipher string + + myVpnNetworks []netip.Prefix + myVpnNetworksTable *bart.Table[struct{}] + myVpnAddrs []netip.Addr + myVpnAddrsTable *bart.Table[struct{}] + myVpnBroadcastAddrsTable *bart.Table[struct{}] } func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { @@ -45,16 +62,16 @@ func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { return pki, nil } -func (p *PKI) GetCertState() *CertState { - return p.cs.Load() -} - -func (p *PKI) GetCAPool() *cert.NebulaCAPool { +func (p *PKI) GetCAPool() *cert.CAPool { return p.caPool.Load() } +func (p *PKI) getCertState() *CertState { + return p.cs.Load() +} + func (p *PKI) reload(c *config.C, initial bool) error { - err := p.reloadCert(c, initial) + err := p.reloadCerts(c, initial) if err != nil { if initial { return err @@ -73,33 +90,94 @@ func (p *PKI) reload(c *config.C, initial bool) error { return nil } -func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError { - cs, err := newCertStateFromConfig(c) +func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { + newState, err := newCertStateFromConfig(c) if err != nil { return util.NewContextualError("Could not load client cert", nil, err) } if !initial { - //TODO: include check for mask equality as well + currentState := p.cs.Load() + if newState.v1Cert != nil { + if currentState.v1Cert == nil { + return util.NewContextualError("v1 certificate was added, restart required", nil, err) + } - // did IP in cert change? if so, don't set - currentCert := p.cs.Load().Certificate - oldIPs := currentCert.Details.Ips - newIPs := cs.Certificate.Details.Ips - if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() { + // did IP in cert change? if so, don't set + if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) { + return util.NewContextualError( + "Networks in new cert was different from old", + m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()}, + nil, + ) + } + + if currentState.v1Cert.Curve() != newState.v1Cert.Curve() { + return util.NewContextualError( + "Curve in new cert was different from old", + m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()}, + nil, + ) + } + + } else if currentState.v1Cert != nil { + //TODO: CERT-V2 we should be able to tear this down + return util.NewContextualError("v1 certificate was removed, restart required", nil, err) + } + + if newState.v2Cert != nil { + if currentState.v2Cert == nil { + return util.NewContextualError("v2 certificate was added, restart required", nil, err) + } + + // did IP in cert change? if so, don't set + if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) { + return util.NewContextualError( + "Networks in new cert was different from old", + m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()}, + nil, + ) + } + + if currentState.v2Cert.Curve() != newState.v2Cert.Curve() { + return util.NewContextualError( + "Curve in new cert was different from old", + m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()}, + nil, + ) + } + + } else if currentState.v2Cert != nil { + return util.NewContextualError("v2 certificate was removed, restart required", nil, err) + } + + // Cipher cant be hot swapped so just leave it at what it was before + newState.cipher = currentState.cipher + + } else { + newState.cipher = c.GetString("cipher", "aes") + //TODO: this sucks and we should make it not a global + switch newState.cipher { + case "aes": + noiseEndianness = binary.BigEndian + case "chachapoly": + noiseEndianness = binary.LittleEndian + default: return util.NewContextualError( - "IP in new cert was different from old", - m{"new_ip": newIPs[0], "old_ip": oldIPs[0]}, + "unknown cipher", + m{"cipher": newState.cipher}, nil, ) } } - p.cs.Store(cs) + p.cs.Store(newState) + + //TODO: CERT-V2 newState needs a stringer that does json if initial { - p.l.WithField("cert", cs.Certificate).Debug("Client nebula certificate") + p.l.WithField("cert", newState).Debug("Client nebula certificate(s)") } else { - p.l.WithField("cert", cs.Certificate).Info("Client cert refreshed from disk") + p.l.WithField("cert", newState).Info("Client certificate(s) refreshed from disk") } return nil } @@ -115,34 +193,68 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError { return nil } -func newCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) { - // Marshal the certificate to ensure it is valid - rawCertificate, err := certificate.Marshal() - if err != nil { - return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err) +func (cs *CertState) GetDefaultCertificate() cert.Certificate { + c := cs.getCertificate(cs.defaultVersion) + if c == nil { + panic("No default certificate found") + } + return c +} + +func (cs *CertState) getCertificate(v cert.Version) cert.Certificate { + switch v { + case cert.Version1: + return cs.v1Cert + case cert.Version2: + return cs.v2Cert } - publicKey := certificate.Details.PublicKey - cs := &CertState{ - RawCertificate: rawCertificate, - Certificate: certificate, - PrivateKey: privateKey, - PublicKey: publicKey, + return nil +} + +// getHandshakeBytes returns the cached bytes to be used in a handshake message for the requested version. +// Callers must check if the return []byte is nil. +func (cs *CertState) getHandshakeBytes(v cert.Version) []byte { + switch v { + case cert.Version1: + return cs.v1HandshakeBytes + case cert.Version2: + return cs.v2HandshakeBytes + default: + return nil + } +} + +func (cs *CertState) String() string { + b, err := cs.MarshalJSON() + if err != nil { + return fmt.Sprintf("error marshaling certificate state: %v", err) + } + return string(b) +} + +func (cs *CertState) MarshalJSON() ([]byte, error) { + msg := []json.RawMessage{} + if cs.v1Cert != nil { + b, err := cs.v1Cert.MarshalJSON() + if err != nil { + return nil, err + } + msg = append(msg, b) } - cs.Certificate.Details.PublicKey = nil - rawCertNoKey, err := cs.Certificate.Marshal() - if err != nil { - return nil, fmt.Errorf("error marshalling certificate no key: %s", err) + if cs.v2Cert != nil { + b, err := cs.v2Cert.MarshalJSON() + if err != nil { + return nil, err + } + msg = append(msg, b) } - cs.RawCertificateNoKey = rawCertNoKey - // put public key back - cs.Certificate.Details.PublicKey = cs.PublicKey - return cs, nil + + return json.Marshal(msg) } func newCertStateFromConfig(c *config.C) (*CertState, error) { - var pemPrivateKey []byte var err error privPathOrPEM := c.GetString("pki.key", "") @@ -150,20 +262,9 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { return nil, errors.New("no pki.key path or PEM data provided") } - if strings.Contains(privPathOrPEM, "-----BEGIN") { - pemPrivateKey = []byte(privPathOrPEM) - privPathOrPEM = "" - - } else { - pemPrivateKey, err = os.ReadFile(privPathOrPEM) - if err != nil { - return nil, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err) - } - } - - rawKey, _, curve, err := cert.UnmarshalPrivateKey(pemPrivateKey) + rawKey, curve, isPkcs11, err := loadPrivateKey(privPathOrPEM) if err != nil { - return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) + return nil, err } var rawCert []byte @@ -184,27 +285,200 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { } } - nebulaCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert) - if err != nil { - return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err) + var crt, v1, v2 cert.Certificate + for { + // Load the certificate + crt, rawCert, err = loadCertificate(rawCert) + if err != nil { + return nil, err + } + + switch crt.Version() { + case cert.Version1: + if v1 != nil { + return nil, fmt.Errorf("v1 certificate already found in pki.cert") + } + v1 = crt + case cert.Version2: + if v2 != nil { + return nil, fmt.Errorf("v2 certificate already found in pki.cert") + } + v2 = crt + default: + return nil, fmt.Errorf("unknown certificate version %v", crt.Version()) + } + + if len(rawCert) == 0 || strings.TrimSpace(string(rawCert)) == "" { + break + } } - if nebulaCert.Expired(time.Now()) { - return nil, fmt.Errorf("nebula certificate for this host is expired") + if v1 == nil && v2 == nil { + return nil, errors.New("no certificates found in pki.cert") } - if len(nebulaCert.Details.Ips) == 0 { - return nil, fmt.Errorf("no IPs encoded in certificate") + useDefaultVersion := uint32(1) + if v1 == nil { + // The only condition that requires v2 as the default is if only a v2 certificate is present + // We do this to avoid having to configure it specifically in the config file + useDefaultVersion = 2 } - if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil { - return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") + rawDefaultVersion := c.GetUint32("pki.default_version", useDefaultVersion) + var defaultVersion cert.Version + switch rawDefaultVersion { + case 1: + if v1 == nil { + return nil, fmt.Errorf("can not use pki.default_version 1 without a v1 certificate in pki.cert") + } + defaultVersion = cert.Version1 + case 2: + defaultVersion = cert.Version2 + default: + return nil, fmt.Errorf("unknown pki.default_version: %v", rawDefaultVersion) } - return newCertState(nebulaCert, rawKey) + return newCertState(defaultVersion, v1, v2, isPkcs11, curve, rawKey) } -func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) { +func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) { + cs := CertState{ + privateKey: privateKey, + pkcs11Backed: pkcs11backed, + myVpnNetworksTable: new(bart.Table[struct{}]), + myVpnAddrsTable: new(bart.Table[struct{}]), + myVpnBroadcastAddrsTable: new(bart.Table[struct{}]), + } + + if v1 != nil && v2 != nil { + if !slices.Equal(v1.PublicKey(), v2.PublicKey()) { + return nil, util.NewContextualError("v1 and v2 public keys are not the same, ignoring", nil, nil) + } + + if v1.Curve() != v2.Curve() { + return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil) + } + + //TODO: CERT-V2 make sure v2 has v1s address + + cs.defaultVersion = dv + } + + if v1 != nil { + if pkcs11backed { + //NOTE: We do not currently have a method to verify a public private key pair when the private key is in an hsm + } else { + if err := v1.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil { + return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") + } + } + + v1hs, err := v1.MarshalForHandshakes() + if err != nil { + return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err) + } + cs.v1Cert = v1 + cs.v1HandshakeBytes = v1hs + + if cs.defaultVersion == 0 { + cs.defaultVersion = cert.Version1 + } + } + + if v2 != nil { + if pkcs11backed { + //NOTE: We do not currently have a method to verify a public private key pair when the private key is in an hsm + } else { + if err := v2.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil { + return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") + } + } + + v2hs, err := v2.MarshalForHandshakes() + if err != nil { + return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err) + } + cs.v2Cert = v2 + cs.v2HandshakeBytes = v2hs + + if cs.defaultVersion == 0 { + cs.defaultVersion = cert.Version2 + } + } + + var crt cert.Certificate + crt = cs.getCertificate(cert.Version2) + if crt == nil { + // v2 certificates are a superset, only look at v1 if its all we have + crt = cs.getCertificate(cert.Version1) + } + + for _, network := range crt.Networks() { + cs.myVpnNetworks = append(cs.myVpnNetworks, network) + cs.myVpnNetworksTable.Insert(network, struct{}{}) + + cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr()) + cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{}) + + if network.Addr().Is4() { + addr := network.Masked().Addr().As4() + mask := net.CIDRMask(network.Bits(), network.Addr().BitLen()) + binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask)) + cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()), struct{}{}) + } + } + + return &cs, nil +} + +func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) { + var pemPrivateKey []byte + if strings.Contains(privPathOrPEM, "-----BEGIN") { + pemPrivateKey = []byte(privPathOrPEM) + privPathOrPEM = "" + rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey) + if err != nil { + return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) + } + } else if strings.HasPrefix(privPathOrPEM, "pkcs11:") { + rawKey = []byte(privPathOrPEM) + return rawKey, cert.Curve_P256, true, nil + } else { + pemPrivateKey, err = os.ReadFile(privPathOrPEM) + if err != nil { + return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err) + } + rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey) + if err != nil { + return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) + } + } + + return +} + +func loadCertificate(b []byte) (cert.Certificate, []byte, error) { + c, b, err := cert.UnmarshalCertificateFromPEM(b) + if err != nil { + return nil, b, fmt.Errorf("error while unmarshaling pki.cert: %w", err) + } + + if c.Expired(time.Now()) { + return nil, b, fmt.Errorf("nebula certificate for this host is expired") + } + + if len(c.Networks()) == 0 { + return nil, b, fmt.Errorf("no networks encoded in certificate") + } + + if c.IsCA() { + return nil, b, fmt.Errorf("host certificate is a CA certificate") + } + + return c, b, nil +} + +func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { var rawCA []byte var err error @@ -223,11 +497,11 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, er } } - caPool, err := cert.NewCAPoolFromBytes(rawCA) + caPool, err := cert.NewCAPoolFromPEM(rawCA) if errors.Is(err, cert.ErrExpired) { var expired int for _, crt := range caPool.CAs { - if crt.Expired(time.Now()) { + if crt.Certificate.Expired(time.Now()) { expired++ l.WithField("cert", crt).Warn("expired certificate present in CA pool") } diff --git a/relay_manager.go b/relay_manager.go index 1a3a4d4..7565350 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -9,6 +9,7 @@ import ( "sync/atomic" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" ) @@ -72,7 +73,7 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti Type: relayType, State: state, LocalIndex: index, - PeerIp: vpnIp, + PeerAddr: vpnIp, } if remoteIdx != nil { @@ -91,40 +92,71 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) { relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex) if !ok { - rm.l.WithFields(logrus.Fields{"relay": relayHostInfo.vpnIp, + fields := logrus.Fields{ + "relay": relayHostInfo.vpnAddrs[0], "initiatorRelayIndex": m.InitiatorRelayIndex, - "relayFrom": m.RelayFromIp, - "relayTo": m.RelayToIp}).Info("relayManager failed to update relay") + } + + if m.RelayFromAddr == nil { + fields["relayFrom"] = m.OldRelayFromAddr + } else { + fields["relayFrom"] = m.RelayFromAddr + } + + if m.RelayToAddr == nil { + fields["relayTo"] = m.OldRelayToAddr + } else { + fields["relayTo"] = m.RelayToAddr + } + + rm.l.WithFields(fields).Info("relayManager failed to update relay") return nil, fmt.Errorf("unknown relay") } return relay, nil } -func (rm *relayManager) HandleControlMsg(h *HostInfo, m *NebulaControl, f *Interface) { - - switch m.Type { - case NebulaControl_CreateRelayRequest: - rm.handleCreateRelayRequest(h, f, m) - case NebulaControl_CreateRelayResponse: - rm.handleCreateRelayResponse(h, f, m) +func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) { + msg := &NebulaControl{} + err := msg.Unmarshal(d) + if err != nil { + h.logger(f.l).WithError(err).Error("Failed to unmarshal control message") + return } + var v cert.Version + if msg.OldRelayFromAddr > 0 || msg.OldRelayToAddr > 0 { + v = cert.Version1 + + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], msg.OldRelayFromAddr) + msg.RelayFromAddr = netAddrToProtoAddr(netip.AddrFrom4(b)) + + binary.BigEndian.PutUint32(b[:], msg.OldRelayToAddr) + msg.RelayToAddr = netAddrToProtoAddr(netip.AddrFrom4(b)) + } else { + v = cert.Version2 + } + + switch msg.Type { + case NebulaControl_CreateRelayRequest: + rm.handleCreateRelayRequest(v, h, f, msg) + case NebulaControl_CreateRelayResponse: + rm.handleCreateRelayResponse(v, h, f, msg) + } } -func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *NebulaControl) { +func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) { rm.l.WithFields(logrus.Fields{ - "relayFrom": m.RelayFromIp, - "relayTo": m.RelayToIp, + "relayFrom": protoAddrToNetAddr(m.RelayFromAddr), + "relayTo": protoAddrToNetAddr(m.RelayToAddr), "initiatorRelayIndex": m.InitiatorRelayIndex, "responderRelayIndex": m.ResponderRelayIndex, - "vpnIp": h.vpnIp}). + "vpnAddrs": h.vpnAddrs}). Info("handleCreateRelayResponse") - target := m.RelayToIp - //TODO: IPV6-WORK - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], m.RelayToIp) - targetAddr := netip.AddrFrom4(b) + + target := m.RelayToAddr + targetAddr := protoAddrToNetAddr(target) relay, err := rm.EstablishRelay(h, m) if err != nil { @@ -136,68 +168,88 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m * return } // I'm the middle man. Let the initiator know that the I've established the relay they requested. - peerHostInfo := rm.hostmap.QueryVpnIp(relay.PeerIp) + peerHostInfo := rm.hostmap.QueryVpnAddr(relay.PeerAddr) if peerHostInfo == nil { - rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer") + rm.l.WithField("relayTo", relay.PeerAddr).Error("Can't find a HostInfo for peer") return } peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr) if !ok { - rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo") + rm.l.WithField("relayTo", peerHostInfo.vpnAddrs[0]).Error("peerRelay does not have Relay state for relayTo") return } - if peerRelay.State == PeerRequested { - //TODO: IPV6-WORK - b = peerHostInfo.vpnIp.As4() - peerRelay.State = Established + switch peerRelay.State { + case Requested: + // I initiated the request to this peer, but haven't heard back from the peer yet. I must wait for this peer + // to respond to complete the connection. + case PeerRequested, Disestablished, Established: + peerHostInfo.relayState.UpdateRelayForByIpState(targetAddr, Established) resp := NebulaControl{ Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: peerRelay.LocalIndex, InitiatorRelayIndex: peerRelay.RemoteIndex, - RelayFromIp: binary.BigEndian.Uint32(b[:]), - RelayToIp: uint32(target), } + + if v == cert.Version1 { + peer := peerHostInfo.vpnAddrs[0] + if !peer.Is4() { + rm.l.WithField("relayFrom", peer). + WithField("relayTo", target). + WithField("initiatorRelayIndex", resp.InitiatorRelayIndex). + WithField("responderRelayIndex", resp.ResponderRelayIndex). + WithField("vpnAddrs", peerHostInfo.vpnAddrs). + Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address") + return + } + + b := peer.As4() + resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = targetAddr.As4() + resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + } else { + resp.RelayFromAddr = netAddrToProtoAddr(peerHostInfo.vpnAddrs[0]) + resp.RelayToAddr = target + } + msg, err := resp.Marshal() if err != nil { - rm.l. - WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") + rm.l.WithError(err). + Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") } else { f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ - "relayFrom": resp.RelayFromIp, - "relayTo": resp.RelayToIp, + "relayFrom": resp.RelayFromAddr, + "relayTo": resp.RelayToAddr, "initiatorRelayIndex": resp.InitiatorRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex, - "vpnIp": peerHostInfo.vpnIp}). + "vpnAddrs": peerHostInfo.vpnAddrs}). Info("send CreateRelayResponse") } } } -func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *NebulaControl) { - //TODO: IPV6-WORK - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], m.RelayFromIp) - from := netip.AddrFrom4(b) - - binary.BigEndian.PutUint32(b[:], m.RelayToIp) - target := netip.AddrFrom4(b) +func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) { + from := protoAddrToNetAddr(m.RelayFromAddr) + target := protoAddrToNetAddr(m.RelayToAddr) logMsg := rm.l.WithFields(logrus.Fields{ "relayFrom": from, "relayTo": target, "initiatorRelayIndex": m.InitiatorRelayIndex, - "vpnIp": h.vpnIp}) + "vpnAddrs": h.vpnAddrs}) logMsg.Info("handleCreateRelayRequest") // Is the source of the relay me? This should never happen, but did happen due to // an issue migrating relays over to newly re-handshaked host info objects. - if from == f.myVpnNet.Addr() { + _, found := f.myVpnAddrsTable.Lookup(from) + if found { logMsg.WithField("myIP", from).Error("Discarding relay request from myself") return } + // Is the target of the relay me? - if target == f.myVpnNet.Addr() { + _, found = f.myVpnAddrsTable.Lookup(target) + if found { existingRelay, ok := h.relayState.QueryRelayForByIp(from) if ok { switch existingRelay.State { @@ -215,6 +267,21 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N "existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") return } + case Disestablished: + if existingRelay.RemoteIndex != m.InitiatorRelayIndex { + // We got a brand new Relay request, because its index is different than what we saw before. + // This should never happen. The peer should never change an index, once created. + logMsg.WithFields(logrus.Fields{ + "existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") + return + } + // Mark the relay as 'Established' because it's safe to use again + h.relayState.UpdateRelayForByIpState(from, Established) + case PeerRequested: + // I should never be in this state, because I am terminal, not forwarding. + logMsg.WithFields(logrus.Fields{ + "existingRemoteIndex": existingRelay.RemoteIndex, + "state": existingRelay.State}).Error("Unexpected Relay State found") } } else { _, err := AddRelay(rm.l, h, f.hostMap, from, &m.InitiatorRelayIndex, TerminalType, Established) @@ -226,21 +293,26 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N relay, ok := h.relayState.QueryRelayForByIp(from) if !ok { - logMsg.Error("Relay State not found") + logMsg.WithField("from", from).Error("Relay State not found") return } - //TODO: IPV6-WORK - fromB := from.As4() - targetB := target.As4() - resp := NebulaControl{ Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: relay.LocalIndex, InitiatorRelayIndex: relay.RemoteIndex, - RelayFromIp: binary.BigEndian.Uint32(fromB[:]), - RelayToIp: binary.BigEndian.Uint32(targetB[:]), } + + if v == cert.Version1 { + b := from.As4() + resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = target.As4() + resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + } else { + resp.RelayFromAddr = netAddrToProtoAddr(from) + resp.RelayToAddr = netAddrToProtoAddr(target) + } + msg, err := resp.Marshal() if err != nil { logMsg. @@ -248,12 +320,11 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N } else { f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ - //TODO: IPV6-WORK, this used to use the resp object but I am getting lazy now "relayFrom": from, "relayTo": target, "initiatorRelayIndex": resp.InitiatorRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex, - "vpnIp": h.vpnIp}). + "vpnAddrs": h.vpnAddrs}). Info("send CreateRelayResponse") } return @@ -262,7 +333,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N if !rm.GetAmRelay() { return } - peer := rm.hostmap.QueryVpnIp(target) + peer := rm.hostmap.QueryVpnAddr(target) if peer == nil { // Try to establish a connection to this host. If we get a future relay request, // we'll be ready! @@ -273,104 +344,69 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N // Only create relays to peers for whom I have a direct connection return } - sendCreateRequest := false var index uint32 var err error targetRelay, ok := peer.relayState.QueryRelayForByIp(from) if ok { index = targetRelay.LocalIndex - if targetRelay.State == Requested { - sendCreateRequest = true - } } else { // Allocate an index in the hostMap for this relay peer index, err = AddRelay(rm.l, peer, f.hostMap, from, nil, ForwardingType, Requested) if err != nil { return } - sendCreateRequest = true } - if sendCreateRequest { - //TODO: IPV6-WORK - fromB := h.vpnIp.As4() - targetB := target.As4() + peer.relayState.UpdateRelayForByIpState(from, Requested) + // Send a CreateRelayRequest to the peer. + req := NebulaControl{ + Type: NebulaControl_CreateRelayRequest, + InitiatorRelayIndex: index, + } - // Send a CreateRelayRequest to the peer. - req := NebulaControl{ - Type: NebulaControl_CreateRelayRequest, - InitiatorRelayIndex: index, - RelayFromIp: binary.BigEndian.Uint32(fromB[:]), - RelayToIp: binary.BigEndian.Uint32(targetB[:]), - } - msg, err := req.Marshal() - if err != nil { - logMsg. - WithError(err).Error("relayManager Failed to marshal Control message to create relay") - } else { - f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) - rm.l.WithFields(logrus.Fields{ - //TODO: IPV6-WORK another lazy used to use the req object - "relayFrom": h.vpnIp, - "relayTo": target, - "initiatorRelayIndex": req.InitiatorRelayIndex, - "responderRelayIndex": req.ResponderRelayIndex, - "vpnIp": target}). - Info("send CreateRelayRequest") + if v == cert.Version1 { + if !h.vpnAddrs[0].Is4() { + rm.l.WithField("relayFrom", h.vpnAddrs[0]). + WithField("relayTo", target). + WithField("initiatorRelayIndex", req.InitiatorRelayIndex). + WithField("responderRelayIndex", req.ResponderRelayIndex). + WithField("vpnAddr", target). + Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address") + return } + + b := h.vpnAddrs[0].As4() + req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = target.As4() + req.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + } else { + req.RelayFromAddr = netAddrToProtoAddr(h.vpnAddrs[0]) + req.RelayToAddr = netAddrToProtoAddr(target) } + + msg, err := req.Marshal() + if err != nil { + logMsg. + WithError(err).Error("relayManager Failed to marshal Control message to create relay") + } else { + f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) + rm.l.WithFields(logrus.Fields{ + "relayFrom": h.vpnAddrs[0], + "relayTo": target, + "initiatorRelayIndex": req.InitiatorRelayIndex, + "responderRelayIndex": req.ResponderRelayIndex, + "vpnAddr": target}). + Info("send CreateRelayRequest") + } + // Also track the half-created Relay state just received - relay, ok := h.relayState.QueryRelayForByIp(target) + _, ok = h.relayState.QueryRelayForByIp(target) if !ok { - // Add the relay - state := PeerRequested - if targetRelay != nil && targetRelay.State == Established { - state = Established - } - _, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, state) + _, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, PeerRequested) if err != nil { logMsg. WithError(err).Error("relayManager Failed to allocate a local index for relay") return } - } else { - switch relay.State { - case Established: - if relay.RemoteIndex != m.InitiatorRelayIndex { - // We got a brand new Relay request, because its index is different than what we saw before. - // This should never happen. The peer should never change an index, once created. - logMsg.WithFields(logrus.Fields{ - "existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") - return - } - //TODO: IPV6-WORK - fromB := h.vpnIp.As4() - targetB := target.As4() - resp := NebulaControl{ - Type: NebulaControl_CreateRelayResponse, - ResponderRelayIndex: relay.LocalIndex, - InitiatorRelayIndex: relay.RemoteIndex, - RelayFromIp: binary.BigEndian.Uint32(fromB[:]), - RelayToIp: binary.BigEndian.Uint32(targetB[:]), - } - msg, err := resp.Marshal() - if err != nil { - rm.l. - WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") - } else { - f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) - rm.l.WithFields(logrus.Fields{ - //TODO: IPV6-WORK more lazy, used to use resp object - "relayFrom": h.vpnIp, - "relayTo": target, - "initiatorRelayIndex": resp.InitiatorRelayIndex, - "responderRelayIndex": resp.ResponderRelayIndex, - "vpnIp": h.vpnIp}). - Info("send CreateRelayResponse") - } - - case Requested: - // Keep waiting for the other relay to complete - } } } } diff --git a/remote_list.go b/remote_list.go index fa14f42..6baed29 100644 --- a/remote_list.go +++ b/remote_list.go @@ -4,6 +4,7 @@ import ( "context" "net" "net/netip" + "slices" "sort" "strconv" "sync" @@ -17,8 +18,8 @@ import ( type forEachFunc func(addr netip.AddrPort, preferred bool) // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate) -type checkFuncV4 func(vpnIp netip.Addr, to *Ip4AndPort) bool -type checkFuncV6 func(vpnIp netip.Addr, to *Ip6AndPort) bool +type checkFuncV4 func(vpnIp netip.Addr, to *V4AddrPort) bool +type checkFuncV6 func(vpnIp netip.Addr, to *V6AddrPort) bool // CacheMap is a struct that better represents the lighthouse cache for humans // The string key is the owners vpnIp @@ -32,9 +33,6 @@ type Cache struct { Relay []netip.Addr `json:"relay"` } -//TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion -// We will never clean learned/reported information for them as it stands today - // cache is an internal struct that splits v4 and v6 addresses inside the cache map type cache struct { v4 *cacheV4 @@ -48,14 +46,14 @@ type cacheRelay struct { // cacheV4 stores learned and reported ipv4 records under cache type cacheV4 struct { - learned *Ip4AndPort - reported []*Ip4AndPort + learned *V4AddrPort + reported []*V4AddrPort } // cacheV4 stores learned and reported ipv6 records under cache type cacheV6 struct { - learned *Ip6AndPort - reported []*Ip6AndPort + learned *V6AddrPort + reported []*V6AddrPort } type hostnamePort struct { @@ -170,7 +168,7 @@ func (hr *hostnamesResults) Cancel() { } } -func (hr *hostnamesResults) GetIPs() []netip.AddrPort { +func (hr *hostnamesResults) GetAddrs() []netip.AddrPort { var retSlice []netip.AddrPort if hr != nil { p := hr.ips.Load() @@ -189,6 +187,9 @@ type RemoteList struct { // Every interaction with internals requires a lock! sync.RWMutex + // The full list of vpn addresses assigned to this host + vpnAddrs []netip.Addr + // A deduplicated set of addresses. Any accessor should lock beforehand. addrs []netip.AddrPort @@ -212,13 +213,16 @@ type RemoteList struct { } // NewRemoteList creates a new empty RemoteList -func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList { - return &RemoteList{ +func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList { + r := &RemoteList{ + vpnAddrs: make([]netip.Addr, len(vpnAddrs)), addrs: make([]netip.AddrPort, 0), relays: make([]netip.Addr, 0), cache: make(map[netip.Addr]*cache), shouldAdd: shouldAdd, } + copy(r.vpnAddrs, vpnAddrs) + return r } func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) { @@ -268,14 +272,13 @@ func (r *RemoteList) CopyAddrs(preferredRanges []netip.Prefix) []netip.AddrPort // LearnRemote locks and sets the learned slot for the owner vpn ip to the provided addr // Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming. // It will mark the deduplicated address list as dirty, so do not call it unless new information is available -// TODO: this needs to support the allow list list func (r *RemoteList) LearnRemote(ownerVpnIp netip.Addr, remote netip.AddrPort) { r.Lock() defer r.Unlock() if remote.Addr().Is4() { - r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPortFromNetIP(remote.Addr(), remote.Port())) + r.unlockedSetLearnedV4(ownerVpnIp, netAddrToProtoV4AddrPort(remote.Addr(), remote.Port())) } else { - r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPortFromNetIP(remote.Addr(), remote.Port())) + r.unlockedSetLearnedV6(ownerVpnIp, netAddrToProtoV6AddrPort(remote.Addr(), remote.Port())) } } @@ -304,21 +307,21 @@ func (r *RemoteList) CopyCache() *CacheMap { if mc.v4 != nil { if mc.v4.learned != nil { - c.Learned = append(c.Learned, AddrPortFromIp4AndPort(mc.v4.learned)) + c.Learned = append(c.Learned, protoV4AddrPortToNetAddrPort(mc.v4.learned)) } for _, a := range mc.v4.reported { - c.Reported = append(c.Reported, AddrPortFromIp4AndPort(a)) + c.Reported = append(c.Reported, protoV4AddrPortToNetAddrPort(a)) } } if mc.v6 != nil { if mc.v6.learned != nil { - c.Learned = append(c.Learned, AddrPortFromIp6AndPort(mc.v6.learned)) + c.Learned = append(c.Learned, protoV6AddrPortToNetAddrPort(mc.v6.learned)) } for _, a := range mc.v6.reported { - c.Reported = append(c.Reported, AddrPortFromIp6AndPort(a)) + c.Reported = append(c.Reported, protoV6AddrPortToNetAddrPort(a)) } } @@ -379,7 +382,6 @@ func (r *RemoteList) Rebuild(preferredRanges []netip.Prefix) { defer r.Unlock() // Only rebuild if the cache changed - //TODO: shouldRebuild is probably pointless as we don't check for actual change when lighthouse updates come in if r.shouldRebuild { r.unlockedCollect() r.shouldRebuild = false @@ -401,14 +403,14 @@ func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool { // unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the // deduplicated address list as dirty -func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *Ip4AndPort) { +func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *V4AddrPort) { r.shouldRebuild = true r.unlockedGetOrMakeV4(ownerVpnIp).learned = to } // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty -func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*Ip4AndPort, check checkFuncV4) { +func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*V4AddrPort, check checkFuncV4) { r.shouldRebuild = true c := r.unlockedGetOrMakeV4(ownerVpnIp) @@ -423,7 +425,7 @@ func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*Ip4AndPor } } -func (r *RemoteList) unlockedSetRelay(ownerVpnIp, vpnIp netip.Addr, to []netip.Addr) { +func (r *RemoteList) unlockedSetRelay(ownerVpnIp netip.Addr, to []netip.Addr) { r.shouldRebuild = true c := r.unlockedGetOrMakeRelay(ownerVpnIp) @@ -436,12 +438,12 @@ func (r *RemoteList) unlockedSetRelay(ownerVpnIp, vpnIp netip.Addr, to []netip.A // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner // This is only useful for establishing static hosts -func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) { +func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *V4AddrPort) { r.shouldRebuild = true c := r.unlockedGetOrMakeV4(ownerVpnIp) // We are doing the easy append because this is rarely called - c.reported = append([]*Ip4AndPort{to}, c.reported...) + c.reported = append([]*V4AddrPort{to}, c.reported...) if len(c.reported) > MaxRemotes { c.reported = c.reported[:MaxRemotes] } @@ -449,14 +451,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) { // unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the // deduplicated address list as dirty -func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *Ip6AndPort) { +func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *V6AddrPort) { r.shouldRebuild = true r.unlockedGetOrMakeV6(ownerVpnIp).learned = to } // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty -func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPort, check checkFuncV6) { +func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*V6AddrPort, check checkFuncV6) { r.shouldRebuild = true c := r.unlockedGetOrMakeV6(ownerVpnIp) @@ -473,12 +475,12 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPor // unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner // This is only useful for establishing static hosts -func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *Ip6AndPort) { +func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *V6AddrPort) { r.shouldRebuild = true c := r.unlockedGetOrMakeV6(ownerVpnIp) // We are doing the easy append because this is rarely called - c.reported = append([]*Ip6AndPort{to}, c.reported...) + c.reported = append([]*V6AddrPort{to}, c.reported...) if len(c.reported) > MaxRemotes { c.reported = c.reported[:MaxRemotes] } @@ -536,14 +538,14 @@ func (r *RemoteList) unlockedCollect() { for _, c := range r.cache { if c.v4 != nil { if c.v4.learned != nil { - u := AddrPortFromIp4AndPort(c.v4.learned) + u := protoV4AddrPortToNetAddrPort(c.v4.learned) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } } for _, v := range c.v4.reported { - u := AddrPortFromIp4AndPort(v) + u := protoV4AddrPortToNetAddrPort(v) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } @@ -552,14 +554,14 @@ func (r *RemoteList) unlockedCollect() { if c.v6 != nil { if c.v6.learned != nil { - u := AddrPortFromIp6AndPort(c.v6.learned) + u := protoV6AddrPortToNetAddrPort(c.v6.learned) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } } for _, v := range c.v6.reported { - u := AddrPortFromIp6AndPort(v) + u := protoV6AddrPortToNetAddrPort(v) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } @@ -573,10 +575,12 @@ func (r *RemoteList) unlockedCollect() { } } - dnsAddrs := r.hr.GetIPs() + dnsAddrs := r.hr.GetAddrs() for _, addr := range dnsAddrs { if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) { - addrs = append(addrs, addr) + if !r.unlockedIsBad(addr) { + addrs = append(addrs, addr) + } } } @@ -587,6 +591,21 @@ func (r *RemoteList) unlockedCollect() { // unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) { + // Use a map to deduplicate any relay addresses + dedupedRelays := map[netip.Addr]struct{}{} + for _, relay := range r.relays { + dedupedRelays[relay] = struct{}{} + } + r.relays = r.relays[:0] + for relay := range dedupedRelays { + r.relays = append(r.relays, relay) + } + // Put them in a somewhat consistent order after de-duplication + slices.SortFunc(r.relays, func(a, b netip.Addr) int { + return a.Compare(b) + }) + + // Now the addrs n := len(r.addrs) if n < 2 { return @@ -685,7 +704,6 @@ func minInt(a, b int) int { // isPreferred returns true of the ip is contained in the preferredRanges list func isPreferred(ip netip.Addr, preferredRanges []netip.Prefix) bool { - //TODO: this would be better in a CIDR6Tree for _, p := range preferredRanges { if p.Contains(ip) { return true diff --git a/remote_list_test.go b/remote_list_test.go index 62a892b..0caf86a 100644 --- a/remote_list_test.go +++ b/remote_list_test.go @@ -9,11 +9,11 @@ import ( ) func TestRemoteList_Rebuild(t *testing.T) { - rl := NewRemoteList(nil) + rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil) rl.unlockedSetV4( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), - []*Ip4AndPort{ + []*V4AddrPort{ newIp4AndPortFromString("70.199.182.92:1475"), // this is duped newIp4AndPortFromString("172.17.0.182:10101"), newIp4AndPortFromString("172.17.1.1:10101"), // this is duped @@ -25,20 +25,30 @@ func TestRemoteList_Rebuild(t *testing.T) { newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe }, - func(netip.Addr, *Ip4AndPort) bool { return true }, + func(netip.Addr, *V4AddrPort) bool { return true }, ) rl.unlockedSetV6( netip.MustParseAddr("0.0.0.1"), netip.MustParseAddr("0.0.0.1"), - []*Ip6AndPort{ + []*V6AddrPort{ newIp6AndPortFromString("[1::1]:1"), // this is duped newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped newIp6AndPortFromString("[1:100::1]:1"), newIp6AndPortFromString("[1::1]:1"), // this is a dupe newIp6AndPortFromString("[1::1]:2"), // this is a dupe }, - func(netip.Addr, *Ip6AndPort) bool { return true }, + func(netip.Addr, *V6AddrPort) bool { return true }, + ) + + rl.unlockedSetRelay( + netip.MustParseAddr("0.0.0.1"), + []netip.Addr{ + netip.MustParseAddr("1::1"), + netip.MustParseAddr("1.2.3.4"), + netip.MustParseAddr("1.2.3.4"), + netip.MustParseAddr("1::1"), + }, ) rl.Rebuild([]netip.Prefix{}) @@ -76,6 +86,11 @@ func TestRemoteList_Rebuild(t *testing.T) { assert.Equal(t, "[1::1]:2", rl.addrs[8].String()) assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String()) + // assert relay deduplicated + assert.Len(t, rl.relays, 2) + assert.Equal(t, "1.2.3.4", rl.relays[0].String()) + assert.Equal(t, "1::1", rl.relays[1].String()) + // Ensure we can hoist a specific ipv4 range over anything else rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("172.17.0.0/16")}) assert.Len(t, rl.addrs, 10, "addrs contains too many entries") @@ -98,11 +113,11 @@ func TestRemoteList_Rebuild(t *testing.T) { } func BenchmarkFullRebuild(b *testing.B) { - rl := NewRemoteList(nil) + rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil) rl.unlockedSetV4( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), - []*Ip4AndPort{ + []*V4AddrPort{ newIp4AndPortFromString("70.199.182.92:1475"), newIp4AndPortFromString("172.17.0.182:10101"), newIp4AndPortFromString("172.17.1.1:10101"), @@ -112,19 +127,19 @@ func BenchmarkFullRebuild(b *testing.B) { newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port }, - func(netip.Addr, *Ip4AndPort) bool { return true }, + func(netip.Addr, *V4AddrPort) bool { return true }, ) rl.unlockedSetV6( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), - []*Ip6AndPort{ + []*V6AddrPort{ newIp6AndPortFromString("[1::1]:1"), newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port newIp6AndPortFromString("[1:100::1]:1"), newIp6AndPortFromString("[1::1]:1"), // this is a dupe }, - func(netip.Addr, *Ip6AndPort) bool { return true }, + func(netip.Addr, *V6AddrPort) bool { return true }, ) b.Run("no preferred", func(b *testing.B) { @@ -160,11 +175,11 @@ func BenchmarkFullRebuild(b *testing.B) { } func BenchmarkSortRebuild(b *testing.B) { - rl := NewRemoteList(nil) + rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil) rl.unlockedSetV4( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), - []*Ip4AndPort{ + []*V4AddrPort{ newIp4AndPortFromString("70.199.182.92:1475"), newIp4AndPortFromString("172.17.0.182:10101"), newIp4AndPortFromString("172.17.1.1:10101"), @@ -174,19 +189,19 @@ func BenchmarkSortRebuild(b *testing.B) { newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port }, - func(netip.Addr, *Ip4AndPort) bool { return true }, + func(netip.Addr, *V4AddrPort) bool { return true }, ) rl.unlockedSetV6( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), - []*Ip6AndPort{ + []*V6AddrPort{ newIp6AndPortFromString("[1::1]:1"), newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port newIp6AndPortFromString("[1:100::1]:1"), newIp6AndPortFromString("[1::1]:1"), // this is a dupe }, - func(netip.Addr, *Ip6AndPort) bool { return true }, + func(netip.Addr, *V6AddrPort) bool { return true }, ) b.Run("no preferred", func(b *testing.B) { @@ -224,19 +239,19 @@ func BenchmarkSortRebuild(b *testing.B) { }) } -func newIp4AndPortFromString(s string) *Ip4AndPort { +func newIp4AndPortFromString(s string) *V4AddrPort { a := netip.MustParseAddrPort(s) v4Addr := a.Addr().As4() - return &Ip4AndPort{ - Ip: binary.BigEndian.Uint32(v4Addr[:]), + return &V4AddrPort{ + Addr: binary.BigEndian.Uint32(v4Addr[:]), Port: uint32(a.Port()), } } -func newIp6AndPortFromString(s string) *Ip6AndPort { +func newIp6AndPortFromString(s string) *V6AddrPort { a := netip.MustParseAddrPort(s) v6Addr := a.Addr().As16() - return &Ip6AndPort{ + return &V6AddrPort{ Hi: binary.BigEndian.Uint64(v6Addr[:8]), Lo: binary.BigEndian.Uint64(v6Addr[8:]), Port: uint32(a.Port()), diff --git a/service/service.go b/service/service.go index 4ddd301..4339677 100644 --- a/service/service.go +++ b/service/service.go @@ -90,9 +90,9 @@ func New(config *config.C) (*Service, error) { }, }) - ipNet := device.Cidr() + ipNet := device.Networks() pa := tcpip.ProtocolAddress{ - AddressWithPrefix: tcpip.AddrFromSlice(ipNet.Addr().AsSlice()).WithPrefix(), + AddressWithPrefix: tcpip.AddrFromSlice(ipNet[0].Addr().AsSlice()).WithPrefix(), Protocol: ipv4.ProtocolNumber, } if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{ diff --git a/service/service_test.go b/service/service_test.go index 3176209..613758e 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -10,17 +10,17 @@ import ( "dario.cat/mergo" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/e2e" "golang.org/x/sync/errgroup" "gopkg.in/yaml.v2" ) type m map[string]interface{} -func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { - _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), netip.PrefixFrom(udpIp, 24), nil, []string{}) - caB, err := caCrt.MarshalToPEM() +func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { + _, _, myPrivKey, myPEM := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{}) + caB, err := caCrt.MarshalPEM() if err != nil { panic(err) } @@ -79,7 +79,7 @@ func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, } func TestService(t *testing.T) { - ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{ "static_host_map": m{}, "lighthouse": m{ diff --git a/ssh.go b/ssh.go index 2ff0954..203166c 100644 --- a/ssh.go +++ b/ssh.go @@ -77,9 +77,6 @@ func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) { // that callers may invoke to run the configured ssh server. On // failure, it returns nil, error. func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) { - //TODO conntrack list - //TODO print firewall rules or hash? - listen := c.GetString("sshd.listen", "") if listen == "" { return nil, fmt.Errorf("sshd.listen must be provided") @@ -93,7 +90,6 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro return nil, fmt.Errorf("sshd.listen can not use port 22") } - //TODO: no good way to reload this right now hostKeyPathOrKey := c.GetString("sshd.host_key", "") if hostKeyPathOrKey == "" { return nil, fmt.Errorf("sshd.host_key must be provided") @@ -320,7 +316,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "print-cert", - ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn ip", + ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn addr", Flags: func() (*flag.FlagSet, interface{}) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshPrintCertFlags{} @@ -336,7 +332,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "print-tunnel", - ShortDescription: "Prints json details about a tunnel for the provided vpn ip", + ShortDescription: "Prints json details about a tunnel for the provided vpn addr", Flags: func() (*flag.FlagSet, interface{}) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshPrintTunnelFlags{} @@ -364,7 +360,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "change-remote", - ShortDescription: "Changes the remote address used in the tunnel for the provided vpn ip", + ShortDescription: "Changes the remote address used in the tunnel for the provided vpn addr", Flags: func() (*flag.FlagSet, interface{}) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshChangeRemoteFlags{} @@ -378,7 +374,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "close-tunnel", - ShortDescription: "Closes a tunnel for the provided vpn ip", + ShortDescription: "Closes a tunnel for the provided vpn addr", Flags: func() (*flag.FlagSet, interface{}) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshCloseTunnelFlags{} @@ -392,7 +388,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "create-tunnel", - ShortDescription: "Creates a tunnel for the provided vpn ip and address", + ShortDescription: "Creates a tunnel for the provided vpn address", Help: "The lighthouses will be queried for real addresses but you can provide one as well.", Flags: func() (*flag.FlagSet, interface{}) { fl := flag.NewFlagSet("", flag.ContinueOnError) @@ -407,8 +403,8 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "query-lighthouse", - ShortDescription: "Query the lighthouses for the provided vpn ip", - Help: "This command is asynchronous. Only currently known udp ips will be printed.", + ShortDescription: "Query the lighthouses for the provided vpn address", + Help: "This command is asynchronous. Only currently known udp addresses will be printed.", Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { return sshQueryLighthouse(f, fs, a, w) }, @@ -418,7 +414,6 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) error { fs, ok := a.(*sshListHostMapFlags) if !ok { - //TODO: error return nil } @@ -430,7 +425,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er } sort.Slice(hm, func(i, j int) bool { - return hm[i].VpnIp.Compare(hm[j].VpnIp) < 0 + return hm[i].VpnAddrs[0].Compare(hm[j].VpnAddrs[0]) < 0 }) if fs.Json || fs.Pretty { @@ -441,13 +436,12 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er err := js.Encode(hm) if err != nil { - //TODO return nil } } else { for _, v := range hm { - err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, v.RemoteAddrs)) + err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnAddrs, v.RemoteAddrs)) if err != nil { return err } @@ -460,13 +454,12 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWriter) error { fs, ok := a.(*sshListHostMapFlags) if !ok { - //TODO: error return nil } type lighthouseInfo struct { - VpnIp string `json:"vpnIp"` - Addrs *CacheMap `json:"addrs"` + VpnAddr string `json:"vpnAddr"` + Addrs *CacheMap `json:"addrs"` } lightHouse.RLock() @@ -474,15 +467,15 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr x := 0 for k, v := range lightHouse.addrMap { addrMap[x] = lighthouseInfo{ - VpnIp: k.String(), - Addrs: v.CopyCache(), + VpnAddr: k.String(), + Addrs: v.CopyCache(), } x++ } lightHouse.RUnlock() sort.Slice(addrMap, func(i, j int) bool { - return strings.Compare(addrMap[i].VpnIp, addrMap[j].VpnIp) < 0 + return strings.Compare(addrMap[i].VpnAddr, addrMap[j].VpnAddr) < 0 }) if fs.Json || fs.Pretty { @@ -493,7 +486,6 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr err := js.Encode(addrMap) if err != nil { - //TODO return nil } @@ -503,7 +495,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr if err != nil { return err } - err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, string(b))) + err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnAddr, string(b))) if err != nil { return err } @@ -541,20 +533,20 @@ func sshVersion(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { if len(a) == 0 { - return w.WriteLine("No vpn ip was provided") + return w.WriteLine("No vpn address was provided") } - vpnIp, err := netip.ParseAddr(a[0]) + vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - if !vpnIp.IsValid() { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } var cm *CacheMap - rl := ifce.lightHouse.Query(vpnIp) + rl := ifce.lightHouse.Query(vpnAddr) if rl != nil { cm = rl.CopyCache() } @@ -564,26 +556,25 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { flags, ok := fs.(*sshCloseTunnelFlags) if !ok { - //TODO: error return nil } if len(a) == 0 { - return w.WriteLine("No vpn ip was provided") + return w.WriteLine("No vpn address was provided") } - vpnIp, err := netip.ParseAddr(a[0]) + vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - if !vpnIp.IsValid() { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo == nil { - return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) + return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn address: %v", a[0])) } if !flags.LocalOnly { @@ -605,29 +596,28 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { flags, ok := fs.(*sshCreateTunnelFlags) if !ok { - //TODO: error return nil } if len(a) == 0 { - return w.WriteLine("No vpn ip was provided") + return w.WriteLine("No vpn address was provided") } - vpnIp, err := netip.ParseAddr(a[0]) + vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - if !vpnIp.IsValid() { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo != nil { return w.WriteLine(fmt.Sprintf("Tunnel already exists")) } - hostInfo = ifce.handshakeManager.QueryVpnIp(vpnIp) + hostInfo = ifce.handshakeManager.QueryVpnAddr(vpnAddr) if hostInfo != nil { return w.WriteLine(fmt.Sprintf("Tunnel already handshaking")) } @@ -640,7 +630,7 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW } } - hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil) + hostInfo = ifce.handshakeManager.StartHandshake(vpnAddr, nil) if addr.IsValid() { hostInfo.SetRemote(addr) } @@ -651,12 +641,11 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { flags, ok := fs.(*sshChangeRemoteFlags) if !ok { - //TODO: error return nil } if len(a) == 0 { - return w.WriteLine("No vpn ip was provided") + return w.WriteLine("No vpn address was provided") } if flags.Address == "" { @@ -668,18 +657,18 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine("Address could not be parsed") } - vpnIp, err := netip.ParseAddr(a[0]) + vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - if !vpnIp.IsValid() { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo == nil { - return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) + return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn address: %v", a[0])) } hostInfo.SetRemote(addr) @@ -781,33 +770,31 @@ func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWri func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { args, ok := fs.(*sshPrintCertFlags) if !ok { - //TODO: error return nil } - cert := ifce.pki.GetCertState().Certificate + cert := ifce.pki.getCertState().GetDefaultCertificate() if len(a) > 0 { - vpnIp, err := netip.ParseAddr(a[0]) + vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0])) } - if !vpnIp.IsValid() { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo == nil { - return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) + return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn addr: %v", a[0])) } - cert = hostInfo.GetCert() + cert = hostInfo.GetCert().Certificate } if args.Json || args.Pretty { b, err := cert.MarshalJSON() if err != nil { - //TODO: handle it return nil } @@ -816,7 +803,6 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit err := json.Indent(buf, b, "", " ") b = buf.Bytes() if err != nil { - //TODO: handle it return nil } } @@ -825,9 +811,8 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit } if args.Raw { - b, err := cert.MarshalToPEM() + b, err := cert.MarshalPEM() if err != nil { - //TODO: handle it return nil } @@ -840,7 +825,6 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { args, ok := fs.(*sshPrintTunnelFlags) if !ok { - //TODO: error w.WriteLine(fmt.Sprintf("sshPrintRelays failed to convert args type")) return nil } @@ -856,15 +840,15 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr Error error Type string State string - PeerIp netip.Addr + PeerAddr netip.Addr LocalIndex uint32 RemoteIndex uint32 RelayedThrough []netip.Addr } type RelayOutput struct { - NebulaIp netip.Addr - RelayForIps []RelayFor + NebulaAddr netip.Addr + RelayForAddrs []RelayFor } type CmdOutput struct { @@ -880,16 +864,16 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr } for k, v := range relays { - ro := RelayOutput{NebulaIp: v.vpnIp} + ro := RelayOutput{NebulaAddr: v.vpnAddrs[0]} co.Relays = append(co.Relays, &ro) - relayHI := ifce.hostMap.QueryVpnIp(v.vpnIp) + relayHI := ifce.hostMap.QueryVpnAddr(v.vpnAddrs[0]) if relayHI == nil { - ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: errors.New("could not find hostinfo")}) + ro.RelayForAddrs = append(ro.RelayForAddrs, RelayFor{Error: errors.New("could not find hostinfo")}) continue } - for _, vpnIp := range relayHI.relayState.CopyRelayForIps() { + for _, vpnAddr := range relayHI.relayState.CopyRelayForIps() { rf := RelayFor{Error: nil} - r, ok := relayHI.relayState.GetRelayForByIp(vpnIp) + r, ok := relayHI.relayState.GetRelayForByAddr(vpnAddr) if ok { t := "" switch r.Type { @@ -913,19 +897,19 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr rf.LocalIndex = r.LocalIndex rf.RemoteIndex = r.RemoteIndex - rf.PeerIp = r.PeerIp + rf.PeerAddr = r.PeerAddr rf.Type = t rf.State = s if rf.LocalIndex != k { rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k) } } - relayedHI := ifce.hostMap.QueryVpnIp(vpnIp) + relayedHI := ifce.hostMap.QueryVpnAddr(vpnAddr) if relayedHI != nil { rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...) } - ro.RelayForIps = append(ro.RelayForIps, rf) + ro.RelayForAddrs = append(ro.RelayForAddrs, rf) } } err := enc.Encode(co) @@ -938,26 +922,25 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { args, ok := fs.(*sshPrintTunnelFlags) if !ok { - //TODO: error return nil } if len(a) == 0 { - return w.WriteLine("No vpn ip was provided") + return w.WriteLine("No vpn address was provided") } - vpnIp, err := netip.ParseAddr(a[0]) + vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0])) } - if !vpnIp.IsValid() { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo == nil { - return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) + return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn addr: %v", a[0])) } enc := json.NewEncoder(w.GetWriter()) @@ -971,13 +954,15 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr func sshDeviceInfo(ifce *Interface, fs interface{}, w sshd.StringWriter) error { data := struct { - Name string `json:"name"` - Cidr string `json:"cidr"` + Name string `json:"name"` + Cidr []netip.Prefix `json:"cidr"` }{ Name: ifce.inside.Name(), - Cidr: ifce.inside.Cidr().String(), + Cidr: make([]netip.Prefix, len(ifce.inside.Networks())), } + copy(data.Cidr, ifce.inside.Networks()) + flags, ok := fs.(*sshDeviceInfoFlags) if !ok { return fmt.Errorf("internal error: expected flags to be sshDeviceInfoFlags but was %+v", fs) diff --git a/sshd/command.go b/sshd/command.go index 900b01e..66646a6 100644 --- a/sshd/command.go +++ b/sshd/command.go @@ -57,7 +57,6 @@ func execCommand(c *Command, args []string, w StringWriter) error { func dumpCommands(c *radix.Tree, w StringWriter) { err := w.WriteLine("Available commands:") if err != nil { - //TODO: log return } @@ -67,10 +66,7 @@ func dumpCommands(c *radix.Tree, w StringWriter) { } sort.Strings(cmds) - err = w.Write(strings.Join(cmds, "\n") + "\n\n") - if err != nil { - //TODO: log - } + _ = w.Write(strings.Join(cmds, "\n") + "\n\n") } func lookupCommand(c *radix.Tree, sCmd string) (*Command, error) { @@ -119,8 +115,6 @@ func helpCallback(commands *radix.Tree, a []string, w StringWriter) (err error) // We are printing a specific commands help text cmd, err := lookupCommand(commands, a[0]) if err != nil { - //TODO: handle error - //TODO: message the user return } diff --git a/sshd/server.go b/sshd/server.go index 9e8c721..c151f91 100644 --- a/sshd/server.go +++ b/sshd/server.go @@ -80,9 +80,7 @@ func NewSSHServer(l *logrus.Entry) (*SSHServer, error) { s.config = &ssh.ServerConfig{ PublicKeyCallback: cc.Authenticate, - //TODO: AuthLogCallback: s.authAttempt, - //TODO: version string - ServerVersion: fmt.Sprintf("SSH-2.0-Nebula???"), + ServerVersion: fmt.Sprintf("SSH-2.0-Nebula???"), } s.RegisterCommand(&Command{ diff --git a/sshd/session.go b/sshd/session.go index bba2a55..7c5869e 100644 --- a/sshd/session.go +++ b/sshd/session.go @@ -62,7 +62,6 @@ func (s *session) handleChannels(chans <-chan ssh.NewChannel) { func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) { for req := range in { var err error - //TODO: maybe support window sizing? switch req.Type { case "shell": if s.term == nil { @@ -89,9 +88,7 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) { req.Reply(true, nil) s.dispatchCommand(payload.Value, &stringWriter{channel}) - //TODO: Fix error handling and report the proper status back status := struct{ Status uint32 }{uint32(0)} - //TODO: I think this is how we shut down a shell as well? channel.SendRequest("exit-status", false, ssh.Marshal(status)) channel.Close() return @@ -110,7 +107,6 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) { } func (s *session) createTerm(channel ssh.Channel) *terminal.Terminal { - //TODO: PS1 with nebula cert name term := terminal.NewTerminal(channel, s.c.User()+"@nebula > ") term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) { // key 9 is tab @@ -137,7 +133,6 @@ func (s *session) handleInput(channel ssh.Channel) { for { line, err := s.term.ReadLine() if err != nil { - //TODO: log break } @@ -148,7 +143,6 @@ func (s *session) handleInput(channel ssh.Channel) { func (s *session) dispatchCommand(line string, w StringWriter) { args, err := shlex.Split(line, true) if err != nil { - //todo: LOG IT return } @@ -159,13 +153,11 @@ func (s *session) dispatchCommand(line string, w StringWriter) { c, err := lookupCommand(s.commands, args[0]) if err != nil { - //TODO: handle the error return } if c == nil { err := w.WriteLine(fmt.Sprintf("did not understand: %s", line)) - //TODO: log error _ = err dumpCommands(s.commands, w) @@ -177,10 +169,7 @@ func (s *session) dispatchCommand(line string, w StringWriter) { return } - err = execCommand(c, args[1:], w) - if err != nil { - //TODO: log the error - } + _ = execCommand(c, args[1:], w) return } diff --git a/test/assert.go b/test/assert.go index 6c6c795..d34252e 100644 --- a/test/assert.go +++ b/test/assert.go @@ -2,6 +2,7 @@ package test import ( "fmt" + "net/netip" "reflect" "testing" "time" @@ -24,6 +25,11 @@ func AssertDeepCopyEqual(t *testing.T, a interface{}, b interface{}) { } func traverseDeepCopy(t *testing.T, v1 reflect.Value, v2 reflect.Value, name string) bool { + if v1.Type() == v2.Type() && v1.Type() == reflect.TypeOf(netip.Addr{}) { + // Ignore netip.Addr types since they reuse an interned global value + return false + } + switch v1.Kind() { case reflect.Array: for i := 0; i < v1.Len(); i++ { diff --git a/test/tun.go b/test/tun.go index fbf5829..b29d61c 100644 --- a/test/tun.go +++ b/test/tun.go @@ -16,8 +16,8 @@ func (NoopTun) Activate() error { return nil } -func (NoopTun) Cidr() netip.Prefix { - return netip.Prefix{} +func (NoopTun) Networks() []netip.Prefix { + return []netip.Prefix{} } func (NoopTun) Name() string { diff --git a/timeout_test.go b/timeout_test.go index 4c6364e..db36fec 100644 --- a/timeout_test.go +++ b/timeout_test.go @@ -116,10 +116,10 @@ func TestTimerWheel_Purge(t *testing.T) { assert.Equal(t, 0, tw.current) fps := []firewall.Packet{ - {LocalIP: netip.MustParseAddr("0.0.0.1")}, - {LocalIP: netip.MustParseAddr("0.0.0.2")}, - {LocalIP: netip.MustParseAddr("0.0.0.3")}, - {LocalIP: netip.MustParseAddr("0.0.0.4")}, + {LocalAddr: netip.MustParseAddr("0.0.0.1")}, + {LocalAddr: netip.MustParseAddr("0.0.0.2")}, + {LocalAddr: netip.MustParseAddr("0.0.0.3")}, + {LocalAddr: netip.MustParseAddr("0.0.0.4")}, } tw.Add(fps[0], time.Second*1) diff --git a/udp/conn.go b/udp/conn.go index fa4e443..895b0df 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -4,28 +4,19 @@ import ( "net/netip" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/header" ) const MTU = 9001 type EncReader func( addr netip.AddrPort, - out []byte, - packet []byte, - header *header.H, - fwPacket *firewall.Packet, - lhh LightHouseHandlerFunc, - nb []byte, - q int, - localCache firewall.ConntrackCache, + payload []byte, ) type Conn interface { Rebind() error LocalAddr() (netip.AddrPort, error) - ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) + ListenOut(r EncReader) WriteTo(b []byte, addr netip.AddrPort) error ReloadConfig(c *config.C) Close() error @@ -39,7 +30,7 @@ func (NoopConn) Rebind() error { func (NoopConn) LocalAddr() (netip.AddrPort, error) { return netip.AddrPort{}, nil } -func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) { +func (NoopConn) ListenOut(_ EncReader) { return } func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { diff --git a/udp/temp.go b/udp/temp.go deleted file mode 100644 index b281906..0000000 --- a/udp/temp.go +++ /dev/null @@ -1,10 +0,0 @@ -package udp - -import ( - "net/netip" -) - -//TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare - -// TODO: IPV6-WORK this can likely be removed now -type LightHouseHandlerFunc func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte) diff --git a/udp/udp_generic.go b/udp/udp_generic.go index 2d84536..06a4d53 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -15,8 +15,6 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/header" ) type GenericConn struct { @@ -60,7 +58,7 @@ func (u *GenericConn) LocalAddr() (netip.AddrPort, error) { } func (u *GenericConn) ReloadConfig(c *config.C) { - // TODO + } func NewUDPStatsEmitter(udpConns []Conn) func() { @@ -72,12 +70,8 @@ type rawMessage struct { Len uint32 } -func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { - plaintext := make([]byte, MTU) +func (u *GenericConn) ListenOut(r EncReader) { buffer := make([]byte, MTU) - h := &header.H{} - fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) for { // Just read one packet at a time @@ -87,16 +81,6 @@ func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *f return } - r( - netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), - plaintext[:0], - buffer[:n], - h, - fwPacket, - lhf, - nb, - q, - cache.Get(u.l), - ) + r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) } } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 2eee76e..32a567e 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -14,13 +14,9 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/header" "golang.org/x/sys/unix" ) -//TODO: make it support reload as best you can! - type StdConn struct { sysFd int isV4 bool @@ -59,7 +55,6 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in } } - //TODO: support multiple listening IPs (for limiting ipv6) var sa unix.Sockaddr if ip.Is4() { sa4 := &unix.SockaddrInet4{Port: port} @@ -74,11 +69,6 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in return nil, fmt.Errorf("unable to bind to socket: %s", err) } - //TODO: this may be useful for forcing threads into specific cores - //unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU, x) - //v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU) - //l.Println(v, err) - return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err } @@ -120,15 +110,9 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) { } } -func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { - plaintext := make([]byte, MTU) - h := &header.H{} - fwPacket := &firewall.Packet{} +func (u *StdConn) ListenOut(r EncReader) { var ip netip.Addr - nb := make([]byte, 12, 12) - //TODO: should we track this? - //metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015)) msgs, buffers, names := u.PrepareRawMessages(u.batch) read := u.ReadMulti if u.batch == 1 { @@ -142,26 +126,14 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew return } - //metric.Update(int64(n)) for i := 0; i < n; i++ { + // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic if u.isV4 { ip, _ = netip.AddrFromSlice(names[i][4:8]) - //TODO: IPV6-WORK what is not ok? } else { ip, _ = netip.AddrFromSlice(names[i][8:24]) - //TODO: IPV6-WORK what is not ok? } - r( - netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), - plaintext[:0], - buffers[i][:msgs[i].Len], - h, - fwPacket, - lhf, - nb, - q, - cache.Get(u.l), - ) + r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len]) } } } @@ -235,8 +207,6 @@ func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { return &net.OpError{Op: "sendto", Err: err} } - //TODO: handle incomplete writes - return nil } } @@ -266,8 +236,6 @@ func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { return &net.OpError{Op: "sendto", Err: err} } - //TODO: handle incomplete writes - return nil } } @@ -314,7 +282,6 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { } func (u *StdConn) Close() error { - //TODO: this will not interrupt the read loop return syscall.Close(u.sysFd) } diff --git a/udp/udp_linux_32.go b/udp/udp_linux_32.go index 523968c..de8f1cd 100644 --- a/udp/udp_linux_32.go +++ b/udp/udp_linux_32.go @@ -39,7 +39,6 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { buffers[i] = make([]byte, MTU) names[i] = make([]byte, unix.SizeofSockaddrInet6) - //TODO: this is still silly, no need for an array vs := []iovec{ {Base: &buffers[i][0], Len: uint32(len(buffers[i]))}, } diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index 87a0de7..48c5a97 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -42,7 +42,6 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { buffers[i] = make([]byte, MTU) names[i] = make([]byte, unix.SizeofSockaddrInet6) - //TODO: this is still silly, no need for an array vs := []iovec{ {Base: &buffers[i][0], Len: uint64(len(buffers[i]))}, } diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index ee7e1e0..585b642 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -18,9 +18,6 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/header" - "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/conn/winrio" ) @@ -118,12 +115,8 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error { return nil } -func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { - plaintext := make([]byte, MTU) +func (u *RIOConn) ListenOut(r EncReader) { buffer := make([]byte, MTU) - h := &header.H{} - fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) for { // Just read one packet at a time @@ -133,17 +126,7 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew return } - r( - netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), - plaintext[:0], - buffer[:n], - h, - fwPacket, - lhf, - nb, - q, - cache.Get(u.l), - ) + r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n]) } } diff --git a/udp/udp_tester.go b/udp/udp_tester.go index f03a353..8d5e6c1 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -10,7 +10,6 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" ) @@ -107,18 +106,13 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { return nil } -func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { - plaintext := make([]byte, MTU) - h := &header.H{} - fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) - +func (u *TesterConn) ListenOut(r EncReader) { for { p, ok := <-u.RxPackets if !ok { return } - r(p.From, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r(p.From, p.Data) } }