From a56a97e5c378d22e1b8d468d4d7fbbd2184a6e9c Mon Sep 17 00:00:00 2001 From: John Maguire Date: Mon, 3 Apr 2023 13:59:38 -0400 Subject: [PATCH 1/8] Add ability to encrypt CA private key at rest (#386) Fixes #8. `nebula-cert ca` now supports encrypting the CA's private key with a passphrase. Pass `-encrypt` in order to be prompted for a passphrase. Encryption is performed using AES-256-GCM and Argon2id for KDF. KDF parameters default to RFC recommendations, but can be overridden via CLI flags `-argon-memory`, `-argon-parallelism`, and `-argon-iterations`. --- CHANGELOG.md | 7 + cert/cert.go | 151 +++++++++++++++- cert/cert.pb.go | 281 ++++++++++++++++++++++++++++-- cert/cert.proto | 20 ++- cert/cert_test.go | 85 +++++++++ cert/crypto.go | 140 +++++++++++++++ cert/crypto_test.go | 25 +++ cmd/nebula-cert/ca.go | 83 +++++++-- cmd/nebula-cert/ca_test.go | 95 +++++++++- cmd/nebula-cert/main.go | 4 +- cmd/nebula-cert/passwords.go | 28 +++ cmd/nebula-cert/passwords_test.go | 10 ++ cmd/nebula-cert/sign.go | 35 +++- cmd/nebula-cert/sign_test.go | 138 ++++++++++++--- go.mod | 2 +- go.sum | 2 - 16 files changed, 1037 insertions(+), 69 deletions(-) create mode 100644 cert/crypto.go create mode 100644 cert/crypto_test.go create mode 100644 cmd/nebula-cert/passwords.go create mode 100644 cmd/nebula-cert/passwords_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f0ebbb..e1c4c00 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added +- `nebula-cert ca` now supports encrypting the CA's private key with a + passphrase. Pass `-encrypt` in order to be prompted for a passphrase. + Encryption is performed using AES-256-GCM and Argon2id for KDF. KDF + parameters default to RFC recommendations, but can be overridden via CLI + flags `-argon-memory`, `-argon-parallelism`, and `-argon-iterations`. + ## [1.6.1] - 2022-09-26 ### Fixed diff --git a/cert/cert.go b/cert/cert.go index f3df89c..216efcf 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -9,7 +9,9 @@ import ( "encoding/hex" "encoding/json" "encoding/pem" + "errors" "fmt" + "math" "net" "time" @@ -21,11 +23,12 @@ import ( const publicKeyLen = 32 const ( - CertBanner = "NEBULA CERTIFICATE" - X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY" - X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY" - Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY" - Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY" + CertBanner = "NEBULA CERTIFICATE" + X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY" + X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY" + EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY" + Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY" + Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY" ) type NebulaCertificate struct { @@ -48,8 +51,21 @@ type NebulaCertificateDetails struct { InvertedGroups map[string]struct{} } +type NebulaEncryptedData struct { + EncryptionMetadata NebulaEncryptionMetadata + Ciphertext []byte +} + +type NebulaEncryptionMetadata struct { + EncryptionAlgorithm string + Argon2Parameters Argon2Parameters +} + type m map[string]interface{} +// Returned if we try to unmarshal an encrypted private key without a passphrase +var ErrPrivateKeyEncrypted = errors.New("private key must be decrypted") + // UnmarshalNebulaCertificate will unmarshal a protobuf byte representation of a nebula cert func UnmarshalNebulaCertificate(b []byte) (*NebulaCertificate, error) { if len(b) == 0 { @@ -144,6 +160,30 @@ func MarshalEd25519PrivateKey(key ed25519.PrivateKey) []byte { return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: key}) } +// EncryptAndMarshalX25519PrivateKey is a simple helper to encrypt and PEM encode an X25519 private key +func EncryptAndMarshalEd25519PrivateKey(b []byte, passphrase []byte, kdfParams *Argon2Parameters) ([]byte, error) { + ciphertext, err := aes256Encrypt(passphrase, kdfParams, b) + if err != nil { + return nil, err + } + + b, err = proto.Marshal(&RawNebulaEncryptedData{ + EncryptionMetadata: &RawNebulaEncryptionMetadata{ + EncryptionAlgorithm: "AES-256-GCM", + Argon2Parameters: &RawNebulaArgon2Parameters{ + Version: kdfParams.version, + Memory: kdfParams.Memory, + Parallelism: uint32(kdfParams.Parallelism), + Iterations: kdfParams.Iterations, + Salt: kdfParams.salt, + }, + }, + Ciphertext: ciphertext, + }) + + return pem.EncodeToMemory(&pem.Block{Type: EncryptedEd25519PrivateKeyBanner, Bytes: b}), nil +} + // UnmarshalX25519PrivateKey will try to pem decode an X25519 private key, returning any other bytes b // or an error on failure func UnmarshalX25519PrivateKey(b []byte) ([]byte, []byte, error) { @@ -168,9 +208,13 @@ func UnmarshalEd25519PrivateKey(b []byte) (ed25519.PrivateKey, []byte, error) { if k == nil { return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") } - if k.Type != Ed25519PrivateKeyBanner { + + if k.Type == EncryptedEd25519PrivateKeyBanner { + return nil, r, ErrPrivateKeyEncrypted + } else if k.Type != Ed25519PrivateKeyBanner { return nil, r, fmt.Errorf("bytes did not contain a proper nebula Ed25519 private key banner") } + if len(k.Bytes) != ed25519.PrivateKeySize { return nil, r, fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") } @@ -178,6 +222,101 @@ func UnmarshalEd25519PrivateKey(b []byte) (ed25519.PrivateKey, []byte, error) { return k.Bytes, r, nil } +// UnmarshalNebulaCertificate will unmarshal a protobuf byte representation of a nebula cert into its +// protobuf-generated struct. +func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) { + if len(b) == 0 { + return nil, fmt.Errorf("nil byte array") + } + var rned RawNebulaEncryptedData + err := proto.Unmarshal(b, &rned) + if err != nil { + return nil, err + } + + if rned.EncryptionMetadata == nil { + return nil, fmt.Errorf("encoded EncryptionMetadata was nil") + } + + if rned.EncryptionMetadata.Argon2Parameters == nil { + return nil, fmt.Errorf("encoded Argon2Parameters was nil") + } + + params, err := unmarshalArgon2Parameters(rned.EncryptionMetadata.Argon2Parameters) + if err != nil { + return nil, err + } + + ned := NebulaEncryptedData{ + EncryptionMetadata: NebulaEncryptionMetadata{ + EncryptionAlgorithm: rned.EncryptionMetadata.EncryptionAlgorithm, + Argon2Parameters: *params, + }, + Ciphertext: rned.Ciphertext, + } + + return &ned, nil +} + +func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) { + if params.Version < math.MinInt32 || params.Version > math.MaxInt32 { + return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32) + } + if params.Memory <= 0 || params.Memory > math.MaxUint32 { + return nil, fmt.Errorf("Argon2Parameters Memory must be be greater than 0 and no more than %d KiB", uint32(math.MaxUint32)) + } + if params.Parallelism <= 0 || params.Parallelism > math.MaxUint8 { + return nil, fmt.Errorf("Argon2Parameters Parallelism must be be greater than 0 and no more than %d", math.MaxUint8) + } + if params.Iterations <= 0 || params.Iterations > math.MaxUint32 { + return nil, fmt.Errorf("-argon-iterations must be be greater than 0 and no more than %d", uint32(math.MaxUint32)) + } + + return &Argon2Parameters{ + version: rune(params.Version), + Memory: uint32(params.Memory), + Parallelism: uint8(params.Parallelism), + Iterations: uint32(params.Iterations), + salt: params.Salt, + }, nil + +} + +// DecryptAndUnmarshalEd25519PrivateKey will try to pem decode and decrypt an Ed25519 private key with +// the given passphrase, returning any other bytes b or an error on failure +func DecryptAndUnmarshalEd25519PrivateKey(passphrase, b []byte) (ed25519.PrivateKey, []byte, error) { + k, r := pem.Decode(b) + if k == nil { + return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") + } + + if k.Type != EncryptedEd25519PrivateKeyBanner { + return nil, r, fmt.Errorf("bytes did not contain a proper nebula encrypted Ed25519 private key banner") + } + + ned, err := UnmarshalNebulaEncryptedData(k.Bytes) + if err != nil { + return nil, r, err + } + + var bytes []byte + switch ned.EncryptionMetadata.EncryptionAlgorithm { + case "AES-256-GCM": + bytes, err = aes256Decrypt(passphrase, &ned.EncryptionMetadata.Argon2Parameters, ned.Ciphertext) + if err != nil { + return nil, r, err + } + default: + return nil, r, fmt.Errorf("unsupported encryption algorithm: %s", ned.EncryptionMetadata.EncryptionAlgorithm) + } + + if len(bytes) != ed25519.PrivateKeySize { + return nil, r, fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") + } + + return bytes, r, nil +} + // MarshalX25519PublicKey is a simple helper to PEM encode an X25519 public key func MarshalX25519PublicKey(b []byte) []byte { return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b}) diff --git a/cert/cert.pb.go b/cert/cert.pb.go index 094aefb..1aa1b4b 100644 --- a/cert/cert.pb.go +++ b/cert/cert.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.28.0 -// protoc v3.20.0 +// protoc v3.19.4 // source: cert.proto package cert @@ -188,6 +188,195 @@ func (x *RawNebulaCertificateDetails) GetIssuer() []byte { return nil } +type RawNebulaEncryptedData struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + EncryptionMetadata *RawNebulaEncryptionMetadata `protobuf:"bytes,1,opt,name=EncryptionMetadata,proto3" json:"EncryptionMetadata,omitempty"` + Ciphertext []byte `protobuf:"bytes,2,opt,name=Ciphertext,proto3" json:"Ciphertext,omitempty"` +} + +func (x *RawNebulaEncryptedData) Reset() { + *x = RawNebulaEncryptedData{} + if protoimpl.UnsafeEnabled { + mi := &file_cert_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RawNebulaEncryptedData) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RawNebulaEncryptedData) ProtoMessage() {} + +func (x *RawNebulaEncryptedData) ProtoReflect() protoreflect.Message { + mi := &file_cert_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RawNebulaEncryptedData.ProtoReflect.Descriptor instead. +func (*RawNebulaEncryptedData) Descriptor() ([]byte, []int) { + return file_cert_proto_rawDescGZIP(), []int{2} +} + +func (x *RawNebulaEncryptedData) GetEncryptionMetadata() *RawNebulaEncryptionMetadata { + if x != nil { + return x.EncryptionMetadata + } + return nil +} + +func (x *RawNebulaEncryptedData) GetCiphertext() []byte { + if x != nil { + return x.Ciphertext + } + return nil +} + +type RawNebulaEncryptionMetadata struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + EncryptionAlgorithm string `protobuf:"bytes,1,opt,name=EncryptionAlgorithm,proto3" json:"EncryptionAlgorithm,omitempty"` + Argon2Parameters *RawNebulaArgon2Parameters `protobuf:"bytes,2,opt,name=Argon2Parameters,proto3" json:"Argon2Parameters,omitempty"` +} + +func (x *RawNebulaEncryptionMetadata) Reset() { + *x = RawNebulaEncryptionMetadata{} + if protoimpl.UnsafeEnabled { + mi := &file_cert_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RawNebulaEncryptionMetadata) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RawNebulaEncryptionMetadata) ProtoMessage() {} + +func (x *RawNebulaEncryptionMetadata) ProtoReflect() protoreflect.Message { + mi := &file_cert_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RawNebulaEncryptionMetadata.ProtoReflect.Descriptor instead. +func (*RawNebulaEncryptionMetadata) Descriptor() ([]byte, []int) { + return file_cert_proto_rawDescGZIP(), []int{3} +} + +func (x *RawNebulaEncryptionMetadata) GetEncryptionAlgorithm() string { + if x != nil { + return x.EncryptionAlgorithm + } + return "" +} + +func (x *RawNebulaEncryptionMetadata) GetArgon2Parameters() *RawNebulaArgon2Parameters { + if x != nil { + return x.Argon2Parameters + } + return nil +} + +type RawNebulaArgon2Parameters struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Version int32 `protobuf:"varint,1,opt,name=version,proto3" json:"version,omitempty"` // rune in Go + Memory uint32 `protobuf:"varint,2,opt,name=memory,proto3" json:"memory,omitempty"` + Parallelism uint32 `protobuf:"varint,4,opt,name=parallelism,proto3" json:"parallelism,omitempty"` // uint8 in Go + Iterations uint32 `protobuf:"varint,3,opt,name=iterations,proto3" json:"iterations,omitempty"` + Salt []byte `protobuf:"bytes,5,opt,name=salt,proto3" json:"salt,omitempty"` +} + +func (x *RawNebulaArgon2Parameters) Reset() { + *x = RawNebulaArgon2Parameters{} + if protoimpl.UnsafeEnabled { + mi := &file_cert_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RawNebulaArgon2Parameters) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RawNebulaArgon2Parameters) ProtoMessage() {} + +func (x *RawNebulaArgon2Parameters) ProtoReflect() protoreflect.Message { + mi := &file_cert_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RawNebulaArgon2Parameters.ProtoReflect.Descriptor instead. +func (*RawNebulaArgon2Parameters) Descriptor() ([]byte, []int) { + return file_cert_proto_rawDescGZIP(), []int{4} +} + +func (x *RawNebulaArgon2Parameters) GetVersion() int32 { + if x != nil { + return x.Version + } + return 0 +} + +func (x *RawNebulaArgon2Parameters) GetMemory() uint32 { + if x != nil { + return x.Memory + } + return 0 +} + +func (x *RawNebulaArgon2Parameters) GetParallelism() uint32 { + if x != nil { + return x.Parallelism + } + return 0 +} + +func (x *RawNebulaArgon2Parameters) GetIterations() uint32 { + if x != nil { + return x.Iterations + } + return 0 +} + +func (x *RawNebulaArgon2Parameters) GetSalt() []byte { + if x != nil { + return x.Salt + } + return nil +} + var File_cert_proto protoreflect.FileDescriptor var file_cert_proto_rawDesc = []byte{ @@ -215,9 +404,38 @@ var file_cert_proto_rawDesc = []byte{ 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x49, 0x73, 0x43, 0x41, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x49, 0x73, 0x43, 0x41, 0x12, 0x16, 0x0a, 0x06, 0x49, 0x73, 0x73, 0x75, 0x65, 0x72, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x49, 0x73, 0x73, 0x75, 0x65, - 0x72, 0x42, 0x20, 0x5a, 0x1e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, - 0x73, 0x6c, 0x61, 0x63, 0x6b, 0x68, 0x71, 0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x2f, 0x63, - 0x65, 0x72, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x72, 0x22, 0x8b, 0x01, 0x0a, 0x16, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, + 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x44, 0x61, 0x74, 0x61, 0x12, 0x51, 0x0a, 0x12, + 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, + 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, + 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, 0x12, 0x45, 0x6e, 0x63, + 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, + 0x1e, 0x0a, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74, 0x22, + 0x9c, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, + 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, + 0x30, 0x0a, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, + 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x45, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, + 0x6d, 0x12, 0x4b, 0x0a, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, + 0x65, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x63, 0x65, + 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f, + 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x52, 0x10, 0x41, 0x72, + 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x22, 0xa3, + 0x01, 0x0a, 0x19, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f, + 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, + 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, + 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x6d, 0x6f, 0x72, 0x79, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x6d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x12, 0x20, + 0x0a, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d, + 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, + 0x12, 0x12, 0x0a, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, + 0x73, 0x61, 0x6c, 0x74, 0x42, 0x20, 0x5a, 0x1e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, + 0x6f, 0x6d, 0x2f, 0x73, 0x6c, 0x61, 0x63, 0x6b, 0x68, 0x71, 0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, + 0x61, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -232,18 +450,23 @@ func file_cert_proto_rawDescGZIP() []byte { return file_cert_proto_rawDescData } -var file_cert_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_cert_proto_msgTypes = make([]protoimpl.MessageInfo, 5) var file_cert_proto_goTypes = []interface{}{ (*RawNebulaCertificate)(nil), // 0: cert.RawNebulaCertificate (*RawNebulaCertificateDetails)(nil), // 1: cert.RawNebulaCertificateDetails + (*RawNebulaEncryptedData)(nil), // 2: cert.RawNebulaEncryptedData + (*RawNebulaEncryptionMetadata)(nil), // 3: cert.RawNebulaEncryptionMetadata + (*RawNebulaArgon2Parameters)(nil), // 4: cert.RawNebulaArgon2Parameters } var file_cert_proto_depIdxs = []int32{ 1, // 0: cert.RawNebulaCertificate.Details:type_name -> cert.RawNebulaCertificateDetails - 1, // [1:1] is the sub-list for method output_type - 1, // [1:1] is the sub-list for method input_type - 1, // [1:1] is the sub-list for extension type_name - 1, // [1:1] is the sub-list for extension extendee - 0, // [0:1] is the sub-list for field type_name + 3, // 1: cert.RawNebulaEncryptedData.EncryptionMetadata:type_name -> cert.RawNebulaEncryptionMetadata + 4, // 2: cert.RawNebulaEncryptionMetadata.Argon2Parameters:type_name -> cert.RawNebulaArgon2Parameters + 3, // [3:3] is the sub-list for method output_type + 3, // [3:3] is the sub-list for method input_type + 3, // [3:3] is the sub-list for extension type_name + 3, // [3:3] is the sub-list for extension extendee + 0, // [0:3] is the sub-list for field type_name } func init() { file_cert_proto_init() } @@ -276,6 +499,42 @@ func file_cert_proto_init() { return nil } } + file_cert_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RawNebulaEncryptedData); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_cert_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RawNebulaEncryptionMetadata); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_cert_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RawNebulaArgon2Parameters); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ @@ -283,7 +542,7 @@ func file_cert_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_cert_proto_rawDesc, NumEnums: 0, - NumMessages: 2, + NumMessages: 5, NumExtensions: 0, NumServices: 0, }, diff --git a/cert/cert.proto b/cert/cert.proto index e135dd1..be7b132 100644 --- a/cert/cert.proto +++ b/cert/cert.proto @@ -26,4 +26,22 @@ message RawNebulaCertificateDetails { // sha-256 of the issuer certificate, if this field is blank the cert is self-signed bytes Issuer = 9; -} \ No newline at end of file +} + +message RawNebulaEncryptedData { + RawNebulaEncryptionMetadata EncryptionMetadata = 1; + bytes Ciphertext = 2; +} + +message RawNebulaEncryptionMetadata { + string EncryptionAlgorithm = 1; + RawNebulaArgon2Parameters Argon2Parameters = 2; +} + +message RawNebulaArgon2Parameters { + int32 version = 1; // rune in Go + uint32 memory = 2; + uint32 parallelism = 4; // uint8 in Go + uint32 iterations = 3; + bytes salt = 5; +} diff --git a/cert/cert_test.go b/cert/cert_test.go index 5a82741..ece9f7f 100644 --- a/cert/cert_test.go +++ b/cert/cert_test.go @@ -578,6 +578,91 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA assert.EqualError(t, err, "input did not contain a valid PEM encoded block") } +func TestDecryptAndUnmarshalEd25519PrivateKey(t *testing.T) { + passphrase := []byte("DO NOT USE THIS KEY") + privKey := []byte(`# A good key +-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT +oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl ++Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB +qrlJ69wer3ZUHFXA +-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +`) + shortKey := []byte(`# A key which, once decrypted, is too short +-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCoga5h8owMEBWRSMMJKzuUvWce7 +k0qlBkQmCxiuLh80MuASW70YcKt8jeEIS2axo2V6zAKA9TSMcCsJW1kDDXEtL/xe +GLF5T7sDl5COp4LU3pGxpV+KoeQ/S3gQCAAcnaOtnJQX+aSDnbO3jCHyP7U9CHbs +rQr3bdH3Oy/WiYU= +-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +`) + invalidBanner := []byte(`# Invalid banner (not encrypted) +-----BEGIN NEBULA ED25519 PRIVATE KEY----- +bWRp2CTVFhW9HD/qCd28ltDgK3w8VXSeaEYczDWos8sMUBqDb9jP3+NYwcS4lURG +XgLvodMXZJuaFPssp+WwtA== +-----END NEBULA ED25519 PRIVATE KEY----- +`) + invalidPem := []byte(`# Not a valid PEM format +-BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT +oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl ++Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB +qrlJ69wer3ZUHFXA +-END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +`) + + keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem) + + // Success test case + k, rest, err := DecryptAndUnmarshalEd25519PrivateKey(passphrase, keyBundle) + assert.Nil(t, err) + assert.Len(t, k, 64) + assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) + + // Fail due to short key + k, rest, err = DecryptAndUnmarshalEd25519PrivateKey(passphrase, rest) + assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key") + assert.Nil(t, k) + assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) + + // Fail due to invalid banner + k, rest, err = DecryptAndUnmarshalEd25519PrivateKey(passphrase, rest) + assert.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519 private key banner") + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + + // Fail due to ivalid PEM format, because + // it's missing the requisite pre-encapsulation boundary. + k, rest, err = DecryptAndUnmarshalEd25519PrivateKey(passphrase, rest) + assert.EqualError(t, err, "input did not contain a valid PEM encoded block") + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + + // Fail due to invalid passphrase + k, rest, err = DecryptAndUnmarshalEd25519PrivateKey([]byte("invalid passphrase"), privKey) + assert.EqualError(t, err, "invalid passphrase or corrupt private key") + assert.Nil(t, k) + assert.Equal(t, rest, []byte{}) +} + +func TestEncryptAndMarshalEd25519PrivateKey(t *testing.T) { + // Having proved that decryption works correctly above, we can test the + // encryption function produces a value which can be decrypted + passphrase := []byte("passphrase") + bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") + kdfParams := NewArgon2Parameters(64*1024, 4, 3) + key, err := EncryptAndMarshalEd25519PrivateKey(bytes, passphrase, kdfParams) + assert.Nil(t, err) + + // Verify the "key" can be decrypted successfully + k, rest, err := DecryptAndUnmarshalEd25519PrivateKey(passphrase, key) + assert.Len(t, k, 64) + assert.Equal(t, rest, []byte{}) + assert.Nil(t, err) + + // EncryptAndMarshalEd25519PrivateKey does not create any errors itself +} + func TestUnmarshalX25519PrivateKey(t *testing.T) { privKey := []byte(`# A good key -----BEGIN NEBULA X25519 PRIVATE KEY----- diff --git a/cert/crypto.go b/cert/crypto.go new file mode 100644 index 0000000..94f4c48 --- /dev/null +++ b/cert/crypto.go @@ -0,0 +1,140 @@ +package cert + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "fmt" + "io" + + "golang.org/x/crypto/argon2" +) + +// KDF factors +type Argon2Parameters struct { + version rune + Memory uint32 // KiB + Parallelism uint8 + Iterations uint32 + salt []byte +} + +// Returns a new Argon2Parameters object with current version set +func NewArgon2Parameters(memory uint32, parallelism uint8, iterations uint32) *Argon2Parameters { + return &Argon2Parameters{ + version: argon2.Version, + Memory: memory, // KiB + Parallelism: parallelism, + Iterations: iterations, + } +} + +// Encrypts data using AES-256-GCM and the Argon2id key derivation function +func aes256Encrypt(passphrase []byte, kdfParams *Argon2Parameters, data []byte) ([]byte, error) { + key, err := aes256DeriveKey(passphrase, kdfParams) + if err != nil { + return nil, err + } + + // this should never happen, but since this dictates how our calls into the + // aes package behave and could be catastraphic, let's sanity check this + if len(key) != 32 { + return nil, fmt.Errorf("invalid AES-256 key length (%d) - cowardly refusing to encrypt", len(key)) + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + nonce := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + + ciphertext := gcm.Seal(nil, nonce, data, nil) + blob := joinNonceCiphertext(nonce, ciphertext) + + return blob, nil +} + +// Decrypts data using AES-256-GCM and the Argon2id key derivation function +// Expects the data to include an Argon2id parameter string before the encrypted data +func aes256Decrypt(passphrase []byte, kdfParams *Argon2Parameters, data []byte) ([]byte, error) { + key, err := aes256DeriveKey(passphrase, kdfParams) + if err != nil { + return nil, err + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(block) + + nonce, ciphertext, err := splitNonceCiphertext(data, gcm.NonceSize()) + if err != nil { + return nil, err + } + + plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return nil, fmt.Errorf("invalid passphrase or corrupt private key") + } + + return plaintext, nil +} + +func aes256DeriveKey(passphrase []byte, params *Argon2Parameters) ([]byte, error) { + if params.salt == nil { + params.salt = make([]byte, 32) + if _, err := rand.Read(params.salt); err != nil { + return nil, err + } + } + + // keySize of 32 bytes will result in AES-256 encryption + key, err := deriveKey(passphrase, 32, params) + if err != nil { + return nil, err + } + + return key, nil +} + +// Derives a key from a passphrase using Argon2id +func deriveKey(passphrase []byte, keySize uint32, params *Argon2Parameters) ([]byte, error) { + if params.version != argon2.Version { + return nil, fmt.Errorf("incompatible Argon2 version: %d", params.version) + } + + if params.salt == nil { + return nil, fmt.Errorf("salt must be set in argon2Parameters") + } else if len(params.salt) < 16 { + return nil, fmt.Errorf("salt must be at least 128 bits") + } + + key := argon2.IDKey(passphrase, params.salt, params.Iterations, params.Memory, params.Parallelism, keySize) + + return key, nil +} + +// Prepends nonce to ciphertext +func joinNonceCiphertext(nonce []byte, ciphertext []byte) []byte { + return append(nonce, ciphertext...) +} + +// Splits nonce from ciphertext +func splitNonceCiphertext(blob []byte, nonceSize int) ([]byte, []byte, error) { + if len(blob) <= nonceSize { + return nil, nil, fmt.Errorf("invalid ciphertext blob - blob shorter than nonce length") + } + + return blob[:nonceSize], blob[nonceSize:], nil +} diff --git a/cert/crypto_test.go b/cert/crypto_test.go new file mode 100644 index 0000000..c2e61df --- /dev/null +++ b/cert/crypto_test.go @@ -0,0 +1,25 @@ +package cert + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/argon2" +) + +func TestNewArgon2Parameters(t *testing.T) { + p := NewArgon2Parameters(64*1024, 4, 3) + assert.EqualValues(t, &Argon2Parameters{ + version: argon2.Version, + Memory: 64 * 1024, + Parallelism: 4, + Iterations: 3, + }, p) + p = NewArgon2Parameters(2*1024*1024, 2, 1) + assert.EqualValues(t, &Argon2Parameters{ + version: argon2.Version, + Memory: 2 * 1024 * 1024, + Parallelism: 2, + Iterations: 1, + }, p) +} diff --git a/cmd/nebula-cert/ca.go b/cmd/nebula-cert/ca.go index ce8d5fa..b4f25c9 100644 --- a/cmd/nebula-cert/ca.go +++ b/cmd/nebula-cert/ca.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "io/ioutil" + "math" "net" "os" "strings" @@ -17,15 +18,19 @@ import ( ) type caFlags struct { - set *flag.FlagSet - name *string - duration *time.Duration - outKeyPath *string - outCertPath *string - outQRPath *string - groups *string - ips *string - subnets *string + set *flag.FlagSet + name *string + duration *time.Duration + outKeyPath *string + outCertPath *string + outQRPath *string + groups *string + ips *string + subnets *string + argonMemory *uint + argonIterations *uint + argonParallelism *uint + encryption *bool } func newCaFlags() *caFlags { @@ -39,10 +44,28 @@ func newCaFlags() *caFlags { cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use") cf.ips = cf.set.String("ips", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses") cf.subnets = cf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets") + cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase") + cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase") + cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase") + cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format") return &cf } -func ca(args []string, out io.Writer, errOut io.Writer) error { +func parseArgonParameters(memory uint, parallelism uint, iterations uint) (*cert.Argon2Parameters, error) { + if memory <= 0 || memory > math.MaxUint32 { + return nil, newHelpErrorf("-argon-memory must be be greater than 0 and no more than %d KiB", uint32(math.MaxUint32)) + } + if parallelism <= 0 || parallelism > math.MaxUint8 { + return nil, newHelpErrorf("-argon-parallelism must be be greater than 0 and no more than %d", math.MaxUint8) + } + if iterations <= 0 || iterations > math.MaxUint32 { + return nil, newHelpErrorf("-argon-iterations must be be greater than 0 and no more than %d", uint32(math.MaxUint32)) + } + + return cert.NewArgon2Parameters(uint32(memory), uint8(parallelism), uint32(iterations)), nil +} + +func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error { cf := newCaFlags() err := cf.set.Parse(args) if err != nil { @@ -58,6 +81,12 @@ func ca(args []string, out io.Writer, errOut io.Writer) error { if err := mustFlagString("out-crt", cf.outCertPath); err != nil { return err } + var kdfParams *cert.Argon2Parameters + if *cf.encryption { + if kdfParams, err = parseArgonParameters(*cf.argonMemory, *cf.argonParallelism, *cf.argonIterations); err != nil { + return err + } + } if *cf.duration <= 0 { return &helpError{"-duration must be greater than 0"} @@ -109,6 +138,28 @@ func ca(args []string, out io.Writer, errOut io.Writer) error { } } + var passphrase []byte + if *cf.encryption { + for i := 0; i < 5; i++ { + out.Write([]byte("Enter passphrase: ")) + passphrase, err = pr.ReadPassword() + + if err == ErrNoTerminal { + return fmt.Errorf("out-key must be encrypted interactively") + } else if err != nil { + return fmt.Errorf("error reading passphrase: %s", err) + } + + if len(passphrase) > 0 { + break + } + } + + if len(passphrase) == 0 { + return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext") + } + } + pub, rawPriv, err := ed25519.GenerateKey(rand.Reader) if err != nil { return fmt.Errorf("error while generating ed25519 keys: %s", err) @@ -140,7 +191,17 @@ func ca(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("error while signing: %s", err) } - err = ioutil.WriteFile(*cf.outKeyPath, cert.MarshalEd25519PrivateKey(rawPriv), 0600) + if *cf.encryption { + b, err := cert.EncryptAndMarshalEd25519PrivateKey(rawPriv, passphrase, kdfParams) + if err != nil { + return fmt.Errorf("error while encrypting out-key: %s", err) + } + + err = ioutil.WriteFile(*cf.outKeyPath, b, 0600) + } else { + err = ioutil.WriteFile(*cf.outKeyPath, cert.MarshalEd25519PrivateKey(rawPriv), 0600) + } + if err != nil { return fmt.Errorf("error while writing out-key: %s", err) } diff --git a/cmd/nebula-cert/ca_test.go b/cmd/nebula-cert/ca_test.go index 372a4f1..0ce9182 100644 --- a/cmd/nebula-cert/ca_test.go +++ b/cmd/nebula-cert/ca_test.go @@ -5,8 +5,11 @@ package main import ( "bytes" + "encoding/pem" + "errors" "io/ioutil" "os" + "strings" "testing" "time" @@ -26,8 +29,16 @@ func Test_caHelp(t *testing.T) { assert.Equal( t, "Usage of "+os.Args[0]+" ca : create a self signed certificate authority\n"+ + " -argon-iterations uint\n"+ + " \tOptional: Argon2 iterations parameter used for encrypted private key passphrase (default 1)\n"+ + " -argon-memory uint\n"+ + " \tOptional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase (default 2097152)\n"+ + " -argon-parallelism uint\n"+ + " \tOptional: Argon2 parallelism parameter used for encrypted private key passphrase (default 4)\n"+ " -duration duration\n"+ " \tOptional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\" (default 8760h0m0s)\n"+ + " -encrypt\n"+ + " \tOptional: prompt for passphrase and write out-key in an encrypted format\n"+ " -groups string\n"+ " \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+ " -ips string\n"+ @@ -50,18 +61,38 @@ func Test_ca(t *testing.T) { ob := &bytes.Buffer{} eb := &bytes.Buffer{} + nopw := &StubPasswordReader{ + password: []byte(""), + err: nil, + } + + errpw := &StubPasswordReader{ + password: []byte(""), + err: errors.New("stub error"), + } + + passphrase := []byte("DO NOT USE THIS KEY") + testpw := &StubPasswordReader{ + password: passphrase, + err: nil, + } + + pwPromptOb := "Enter passphrase: " + // required args - assertHelpError(t, ca([]string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb), "-name is required") + assertHelpError(t, ca( + []string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw, + ), "-name is required") assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // ipv4 only ips - assertHelpError(t, ca([]string{"-name", "ipv6", "-ips", "100::100/100"}, ob, eb), "invalid ip definition: can only be ipv4, have 100::100/100") + assertHelpError(t, ca([]string{"-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100") assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // ipv4 only subnets - assertHelpError(t, ca([]string{"-name", "ipv6", "-subnets", "100::100/100"}, ob, eb), "invalid subnet definition: can only be ipv4, have 100::100/100") + assertHelpError(t, ca([]string{"-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100") assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -69,7 +100,7 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args := []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"} - assert.EqualError(t, ca(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) + assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -82,7 +113,7 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()} - assert.EqualError(t, ca(args, ob, eb), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) + assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -96,7 +127,7 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.Nil(t, ca(args, ob, eb)) + assert.Nil(t, ca(args, ob, eb, nopw)) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -122,19 +153,65 @@ func Test_ca(t *testing.T) { assert.Equal(t, "", lCrt.Details.Issuer) assert.True(t, lCrt.CheckSignature(lCrt.Details.PublicKey)) + // test encrypted key + os.Remove(keyF.Name()) + os.Remove(crtF.Name()) + ob.Reset() + eb.Reset() + args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + assert.Nil(t, ca(args, ob, eb, testpw)) + assert.Equal(t, pwPromptOb, ob.String()) + assert.Equal(t, "", eb.String()) + + // read encrypted key file and verify default params + rb, _ = ioutil.ReadFile(keyF.Name()) + k, _ := pem.Decode(rb) + ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes) + assert.Nil(t, err) + // we won't know salt in advance, so just check start of string + assert.Equal(t, uint32(2*1024*1024), ned.EncryptionMetadata.Argon2Parameters.Memory) + assert.Equal(t, uint8(4), ned.EncryptionMetadata.Argon2Parameters.Parallelism) + assert.Equal(t, uint32(1), ned.EncryptionMetadata.Argon2Parameters.Iterations) + + // verify the key is valid and decrypt-able + lKey, b, err = cert.DecryptAndUnmarshalEd25519PrivateKey(passphrase, rb) + assert.Nil(t, err) + assert.Len(t, b, 0) + assert.Len(t, lKey, 64) + + // test when reading passsword results in an error + os.Remove(keyF.Name()) + os.Remove(crtF.Name()) + ob.Reset() + eb.Reset() + args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + assert.Error(t, ca(args, ob, eb, errpw)) + assert.Equal(t, pwPromptOb, ob.String()) + assert.Equal(t, "", eb.String()) + + // test when user fails to enter a password + os.Remove(keyF.Name()) + os.Remove(crtF.Name()) + ob.Reset() + eb.Reset() + args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext") + assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up + assert.Equal(t, "", eb.String()) + // create valid cert/key for overwrite tests os.Remove(keyF.Name()) os.Remove(crtF.Name()) ob.Reset() eb.Reset() args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.Nil(t, ca(args, ob, eb)) + assert.Nil(t, ca(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file ob.Reset() eb.Reset() args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.EqualError(t, ca(args, ob, eb), "refusing to overwrite existing CA key: "+keyF.Name()) + assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name()) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -143,7 +220,7 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.EqualError(t, ca(args, ob, eb), "refusing to overwrite existing CA cert: "+crtF.Name()) + assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name()) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) os.Remove(keyF.Name()) diff --git a/cmd/nebula-cert/main.go b/cmd/nebula-cert/main.go index 3fba40a..b803d30 100644 --- a/cmd/nebula-cert/main.go +++ b/cmd/nebula-cert/main.go @@ -62,11 +62,11 @@ func main() { switch args[0] { case "ca": - err = ca(args[1:], os.Stdout, os.Stderr) + err = ca(args[1:], os.Stdout, os.Stderr, StdinPasswordReader{}) case "keygen": err = keygen(args[1:], os.Stdout, os.Stderr) case "sign": - err = signCert(args[1:], os.Stdout, os.Stderr) + err = signCert(args[1:], os.Stdout, os.Stderr, StdinPasswordReader{}) case "print": err = printCert(args[1:], os.Stdout, os.Stderr) case "verify": diff --git a/cmd/nebula-cert/passwords.go b/cmd/nebula-cert/passwords.go new file mode 100644 index 0000000..8129560 --- /dev/null +++ b/cmd/nebula-cert/passwords.go @@ -0,0 +1,28 @@ +package main + +import ( + "errors" + "fmt" + "os" + + "golang.org/x/term" +) + +var ErrNoTerminal = errors.New("cannot read password from nonexistent terminal") + +type PasswordReader interface { + ReadPassword() ([]byte, error) +} + +type StdinPasswordReader struct{} + +func (pr StdinPasswordReader) ReadPassword() ([]byte, error) { + if !term.IsTerminal(int(os.Stdin.Fd())) { + return nil, ErrNoTerminal + } + + password, err := term.ReadPassword(int(os.Stdin.Fd())) + fmt.Println() + + return password, err +} diff --git a/cmd/nebula-cert/passwords_test.go b/cmd/nebula-cert/passwords_test.go new file mode 100644 index 0000000..d0b64b9 --- /dev/null +++ b/cmd/nebula-cert/passwords_test.go @@ -0,0 +1,10 @@ +package main + +type StubPasswordReader struct { + password []byte + err error +} + +func (pr *StubPasswordReader) ReadPassword() ([]byte, error) { + return pr.password, pr.err +} diff --git a/cmd/nebula-cert/sign.go b/cmd/nebula-cert/sign.go index 4b3b899..68104ad 100644 --- a/cmd/nebula-cert/sign.go +++ b/cmd/nebula-cert/sign.go @@ -1,6 +1,7 @@ package main import ( + "crypto/ed25519" "crypto/rand" "flag" "fmt" @@ -49,7 +50,7 @@ func newSignFlags() *signFlags { } -func signCert(args []string, out io.Writer, errOut io.Writer) error { +func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error { sf := newSignFlags() err := sf.set.Parse(args) if err != nil { @@ -77,8 +78,36 @@ func signCert(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("error while reading ca-key: %s", err) } - caKey, _, err := cert.UnmarshalEd25519PrivateKey(rawCAKey) - if err != nil { + var caKey ed25519.PrivateKey + + // naively attempt to decode the private key as though it is not encrypted + caKey, _, err = cert.UnmarshalEd25519PrivateKey(rawCAKey) + if err == cert.ErrPrivateKeyEncrypted { + // ask for a passphrase until we get one + var passphrase []byte + for i := 0; i < 5; i++ { + out.Write([]byte("Enter passphrase: ")) + passphrase, err = pr.ReadPassword() + + if err == ErrNoTerminal { + return fmt.Errorf("ca-key is encrypted and must be decrypted interactively") + } else if err != nil { + return fmt.Errorf("error reading password: %s", err) + } + + if len(passphrase) > 0 { + break + } + } + if len(passphrase) == 0 { + return fmt.Errorf("cannot open encrypted ca-key without passphrase") + } + + caKey, _, err = cert.DecryptAndUnmarshalEd25519PrivateKey(passphrase, rawCAKey) + if err != nil { + return fmt.Errorf("error while parsing encrypted ca-key: %s", err) + } + } else if err != nil { return fmt.Errorf("error while parsing ca-key: %s", err) } diff --git a/cmd/nebula-cert/sign_test.go b/cmd/nebula-cert/sign_test.go index 4976fa3..afde357 100644 --- a/cmd/nebula-cert/sign_test.go +++ b/cmd/nebula-cert/sign_test.go @@ -6,6 +6,7 @@ package main import ( "bytes" "crypto/rand" + "errors" "io/ioutil" "os" "testing" @@ -58,17 +59,39 @@ func Test_signCert(t *testing.T) { ob := &bytes.Buffer{} eb := &bytes.Buffer{} + nopw := &StubPasswordReader{ + password: []byte(""), + err: nil, + } + + errpw := &StubPasswordReader{ + password: []byte(""), + err: errors.New("stub error"), + } + + passphrase := []byte("DO NOT USE THIS KEY") + testpw := &StubPasswordReader{ + password: passphrase, + err: nil, + } + // required args - assertHelpError(t, signCert([]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb), "-name is required") + assertHelpError(t, signCert( + []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, + ), "-name is required") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) - assertHelpError(t, signCert([]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb), "-ip is required") + assertHelpError(t, signCert( + []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, + ), "-ip is required") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // cannot set -in-pub and -out-key - assertHelpError(t, signCert([]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb), "cannot set both -in-pub and -out-key") + assertHelpError(t, signCert( + []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw, + ), "cannot set both -in-pub and -out-key") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -76,7 +99,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args := []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb), "error while reading ca-key: open ./nope: "+NoSuchFileError) + assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError) // failed to unmarshal key ob.Reset() @@ -86,7 +109,7 @@ func Test_signCert(t *testing.T) { defer os.Remove(caKeyF.Name()) args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb), "error while parsing ca-key: input did not contain a valid PEM encoded block") + assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -98,7 +121,7 @@ func Test_signCert(t *testing.T) { // failed to read cert args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb), "error while reading ca-crt: open ./nope: "+NoSuchFileError) + assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -110,7 +133,7 @@ func Test_signCert(t *testing.T) { defer os.Remove(caCrtF.Name()) args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb), "error while parsing ca-crt: input did not contain a valid PEM encoded block") + assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -129,7 +152,7 @@ func Test_signCert(t *testing.T) { // failed to read pub args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb), "error while reading in-pub: open ./nope: "+NoSuchFileError) + assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -141,7 +164,7 @@ func Test_signCert(t *testing.T) { defer os.Remove(inPubF.Name()) args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb), "error while parsing in-pub: input did not contain a valid PEM encoded block") + assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -155,14 +178,14 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assertHelpError(t, signCert(args, ob, eb), "invalid ip definition: invalid CIDR address: a1.1.1.1/24") + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: invalid CIDR address: a1.1.1.1/24") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assertHelpError(t, signCert(args, ob, eb), "invalid ip definition: can only be ipv4, have 100::100/100") + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -170,14 +193,14 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} - assertHelpError(t, signCert(args, ob, eb), "invalid subnet definition: invalid CIDR address: a") + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: invalid CIDR address: a") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"} - assertHelpError(t, signCert(args, ob, eb), "invalid subnet definition: can only be ipv4, have 100::100/100") + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -191,7 +214,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} - assert.EqualError(t, signCert(args, ob, eb), "refusing to sign, root certificate does not match private key") + assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -199,7 +222,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"} - assert.EqualError(t, signCert(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) + assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -212,7 +235,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"} - assert.EqualError(t, signCert(args, ob, eb), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) + assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) os.Remove(keyF.Name()) @@ -226,7 +249,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Nil(t, signCert(args, ob, eb)) + assert.Nil(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -268,7 +291,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"} - assert.Nil(t, signCert(args, ob, eb)) + assert.Nil(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -283,7 +306,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.EqualError(t, signCert(args, ob, eb), "refusing to sign, root certificate constraints violated: certificate expires after signing certificate") + assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate constraints violated: certificate expires after signing certificate") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -291,14 +314,14 @@ func Test_signCert(t *testing.T) { os.Remove(keyF.Name()) os.Remove(crtF.Name()) args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Nil(t, signCert(args, ob, eb)) + assert.Nil(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing key file os.Remove(crtF.Name()) ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.EqualError(t, signCert(args, ob, eb), "refusing to overwrite existing key: "+keyF.Name()) + assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -306,14 +329,83 @@ func Test_signCert(t *testing.T) { os.Remove(keyF.Name()) os.Remove(crtF.Name()) args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Nil(t, signCert(args, ob, eb)) + assert.Nil(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file os.Remove(keyF.Name()) ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.EqualError(t, signCert(args, ob, eb), "refusing to overwrite existing cert: "+crtF.Name()) + assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) + + // create valid cert/key using encrypted CA key + os.Remove(caKeyF.Name()) + os.Remove(caCrtF.Name()) + os.Remove(keyF.Name()) + os.Remove(crtF.Name()) + ob.Reset() + eb.Reset() + + caKeyF, err = ioutil.TempFile("", "sign-cert.key") + assert.Nil(t, err) + defer os.Remove(caKeyF.Name()) + + caCrtF, err = ioutil.TempFile("", "sign-cert.crt") + assert.Nil(t, err) + defer os.Remove(caCrtF.Name()) + + // generate the encrypted key + caPub, caPriv, _ = ed25519.GenerateKey(rand.Reader) + kdfParams := cert.NewArgon2Parameters(64*1024, 4, 3) + b, _ = cert.EncryptAndMarshalEd25519PrivateKey(caPriv, passphrase, kdfParams) + caKeyF.Write(b) + + ca = cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "ca", + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Minute * 200), + PublicKey: caPub, + IsCA: true, + }, + } + b, _ = ca.MarshalToPEM() + caCrtF.Write(b) + + // test with the proper password + args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + assert.Nil(t, signCert(args, ob, eb, testpw)) + assert.Equal(t, "Enter passphrase: ", ob.String()) + assert.Empty(t, eb.String()) + + // test with the wrong password + ob.Reset() + eb.Reset() + + testpw.password = []byte("invalid password") + args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + assert.Error(t, signCert(args, ob, eb, testpw)) + assert.Equal(t, "Enter passphrase: ", ob.String()) + assert.Empty(t, eb.String()) + + // test with the user not entering a password + ob.Reset() + eb.Reset() + + args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + assert.Error(t, signCert(args, ob, eb, nopw)) + // normally the user hitting enter on the prompt would add newlines between these + assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String()) + assert.Empty(t, eb.String()) + + // test an error condition + ob.Reset() + eb.Reset() + + args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + assert.Error(t, signCert(args, ob, eb, errpw)) + assert.Equal(t, "Enter passphrase: ", ob.String()) + assert.Empty(t, eb.String()) } diff --git a/go.mod b/go.mod index 2b2fafa..ea42666 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( golang.org/x/exp v0.0.0-20230310171629-522b1b587ee0 golang.org/x/net v0.8.0 golang.org/x/sys v0.6.0 + golang.org/x/term v0.6.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard/windows v0.5.3 google.golang.org/protobuf v1.29.0 @@ -42,7 +43,6 @@ require ( github.com/prometheus/procfs v0.9.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect golang.org/x/mod v0.9.0 // indirect - golang.org/x/term v0.6.0 // indirect golang.org/x/tools v0.7.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 6d3febc..5571236 100644 --- a/go.sum +++ b/go.sum @@ -153,8 +153,6 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= -golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2 h1:Jvc7gsqn21cJHCmAWx0LiimpP18LZmUxkT5Mp7EZ1mI= -golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/exp v0.0.0-20230310171629-522b1b587ee0 h1:LGJsf5LRplCck6jUCH3dBL2dmycNruWNF5xugkSlfXw= golang.org/x/exp v0.0.0-20230310171629-522b1b587ee0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= From 6685856b5db2e4f5e8f69eac0af5280d7d1d930c Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Mon, 3 Apr 2023 21:18:16 -0400 Subject: [PATCH 2/8] emit certificate.expiration_ttl_seconds metric (#782) --- interface.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/interface.go b/interface.go index b4822ed..af83abc 100644 --- a/interface.go +++ b/interface.go @@ -380,6 +380,8 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { udpStats := udp.NewUDPStatsEmitter(f.writers) + certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil) + for { select { case <-ctx.Done(): @@ -388,6 +390,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { f.firewall.EmitStats() f.handshakeManager.EmitStats() udpStats() + certExpirationGauge.Update(int64(f.certState.Load().certificate.Details.NotAfter.Sub(time.Now()) / time.Second)) } } } From fd99ce9a7137f577c26d74821290c9ef424f145e Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Tue, 4 Apr 2023 13:42:24 -0500 Subject: [PATCH 3/8] Use fewer test packets (#840) --- connection_manager.go | 28 +++++++++++++++++++--------- connection_manager_test.go | 2 ++ 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/connection_manager.go b/connection_manager.go index 0ea1f75..14086ac 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -183,12 +183,6 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, return } - if n.intf.lightHouse.IsLighthouseIP(hostinfo.vpnIp) { - // We are sending traffic to the lighthouse, let recv_error sort out any issues instead of testing the tunnel - n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval) - return - } - if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok { // We have already sent a test packet and nothing was returned, this hostinfo is dead hostinfo.logger(n.l). @@ -205,10 +199,26 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, Debug("Tunnel status") if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo { - if n.punchy.GetTargetEverything() { - // Maybe the remote is sending us packets but our NAT is blocking it and since we are configured to punch to all - // known remotes, go ahead and do that AND send a test packet + if !outTraffic { + // If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel. + // Just maintain NAT state if configured to do so. n.sendPunch(hostinfo) + n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval) + return + + } + + if n.punchy.GetTargetEverything() { + // This is similar to the old punchy behavior with a slight optimization. + // We aren't receiving traffic but we are sending it, punch on all known + // ips in case we need to re-prime NAT state + n.sendPunch(hostinfo) + } + + if n.intf.lightHouse.IsLighthouseIP(hostinfo.vpnIp) { + // We are sending traffic to the lighthouse, let recv_error sort out any issues instead of testing the tunnel + n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval) + return } // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues diff --git a/connection_manager_test.go b/connection_manager_test.go index e05376d..3d79cb0 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -98,6 +98,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { assert.NotContains(t, nc.in, hostinfo.localIndexId) // Do another traffic check tick, this host should be pending deletion now + nc.Out(hostinfo.localIndexId) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId) assert.NotContains(t, nc.out, hostinfo.localIndexId) @@ -175,6 +176,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { assert.NotContains(t, nc.in, hostinfo.localIndexId) // Do another traffic check tick, this host should be pending deletion now + nc.Out(hostinfo.localIndexId) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId) assert.NotContains(t, nc.out, hostinfo.localIndexId) From d3fe3efcb0f07eddb459f6a85f37dcdad4f80668 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Wed, 5 Apr 2023 10:04:30 -0500 Subject: [PATCH 4/8] Fix handshake retry regression (#842) --- handshake_manager.go | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/handshake_manager.go b/handshake_manager.go index 449a4da..c8a01ca 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -53,10 +53,6 @@ type HandshakeManager struct { metricTimedOut metrics.Counter l *logrus.Logger - // vpnIps is another map similar to the pending hostmap but tracks entries in the wheel instead - // this is to avoid situations where the same vpn ip enters the wheel and causes rapid fire handshaking - vpnIps map[iputil.VpnIp]struct{} - // can be used to trigger outbound handshake for the given vpnIp trigger chan iputil.VpnIp } @@ -70,7 +66,6 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [ config: config, trigger: make(chan iputil.VpnIp, config.triggerBuffer), OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp](config.tryInterval, hsTimeout(config.retries, config.tryInterval)), - vpnIps: map[iputil.VpnIp]struct{}{}, messageMetrics: config.messageMetrics, metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil), metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil), @@ -108,7 +103,6 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.E func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, lighthouseTriggered bool) { hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp) if err != nil { - delete(c.vpnIps, vpnIp) return } hostinfo.Lock() @@ -298,10 +292,7 @@ func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *H hostinfo, created := c.pendingHostMap.AddVpnIp(vpnIp, init) if created { - if _, ok := c.vpnIps[vpnIp]; !ok { - c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval) - } - c.vpnIps[vpnIp] = struct{}{} + c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval) c.metricInitiated.Inc(1) } From e0553822b0c8c7d48509482f4904fa5f0945feac Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Wed, 5 Apr 2023 11:08:23 -0400 Subject: [PATCH 5/8] Use NewGCMTLS (when using experiment boringcrypto) (#803) * Use NewGCMTLS (when using experiment boringcrypto) This change only affects builds built using `GOEXPERIMENT=boringcrypto`. When built with this experiment, we use the NewGCMTLS() method exposed by goboring, which validates that the nonce is strictly monotonically increasing. This is the TLS 1.2 specification for nonce generation (which also matches the method used by the Noise Protocol) - https://github.com/golang/go/blob/go1.19/src/crypto/tls/cipher_suites.go#L520-L522 - https://github.com/golang/go/blob/go1.19/src/crypto/internal/boring/aes.go#L235-L237 - https://github.com/golang/go/blob/go1.19/src/crypto/internal/boring/aes.go#L250 - https://github.com/google/boringssl/blob/ae223d6138807a13006342edfeef32e813246b39/include/openssl/aead.h#L379-L381 - https://github.com/google/boringssl/blob/ae223d6138807a13006342edfeef32e813246b39/crypto/fipsmodule/cipher/e_aes.c#L1082-L1093 * need to lock around EncryptDanger in SendVia * fix link to test vector --- connection_state.go | 3 +- inside.go | 23 +++++++++-- noiseutil/boring.go | 80 +++++++++++++++++++++++++++++++++++++ noiseutil/boring_test.go | 39 ++++++++++++++++++ noiseutil/notboring.go | 14 +++++++ noiseutil/notboring_test.go | 15 +++++++ 6 files changed, 169 insertions(+), 5 deletions(-) create mode 100644 noiseutil/boring.go create mode 100644 noiseutil/boring_test.go create mode 100644 noiseutil/notboring.go create mode 100644 noiseutil/notboring_test.go diff --git a/connection_state.go b/connection_state.go index 2a7be15..4f8a577 100644 --- a/connection_state.go +++ b/connection_state.go @@ -9,6 +9,7 @@ import ( "github.com/flynn/noise" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/noiseutil" ) const ReplayWindow = 1024 @@ -28,7 +29,7 @@ type ConnectionState struct { } func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { - cs := noise.NewCipherSuite(noise.DH25519, noise.CipherAESGCM, noise.HashSHA256) + cs := noise.NewCipherSuite(noise.DH25519, noiseutil.CipherAESGCM, noise.HashSHA256) if f.cipher == "chachapoly" { cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256) } diff --git a/inside.go b/inside.go index 21e2ab7..457fcac 100644 --- a/inside.go +++ b/inside.go @@ -6,6 +6,7 @@ import ( "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/udp" ) @@ -256,6 +257,11 @@ func (f *Interface) SendVia(viaIfc interface{}, ) { via := viaIfc.(*HostInfo) relay := relayIfc.(*Relay) + + if noiseutil.EncryptLockNeeded { + // NOTE: for goboring AESGCMTLS we need to lock because of the nonce check + via.ConnectionState.writeLock.Lock() + } c := via.ConnectionState.messageCounter.Add(1) out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c) @@ -264,6 +270,9 @@ func (f *Interface) SendVia(viaIfc interface{}, // Authenticate the header and payload, but do not encrypt for this message type. // The payload consists of the inner, unencrypted Nebula header, as well as the end-to-end encrypted payload. if len(out)+len(ad)+via.ConnectionState.eKey.Overhead() > cap(out) { + if noiseutil.EncryptLockNeeded { + via.ConnectionState.writeLock.Unlock() + } via.logger(f.l). WithField("outCap", cap(out)). WithField("payloadLen", len(ad)). @@ -285,6 +294,9 @@ func (f *Interface) SendVia(viaIfc interface{}, var err error out, err = via.ConnectionState.eKey.EncryptDanger(out, out, nil, c, nb) + if noiseutil.EncryptLockNeeded { + via.ConnectionState.writeLock.Unlock() + } if err != nil { via.logger(f.l).WithError(err).Info("Failed to EncryptDanger in sendVia") return @@ -313,8 +325,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType out = out[header.Len:] } - //TODO: enable if we do more than 1 tun queue - //ci.writeLock.Lock() + if noiseutil.EncryptLockNeeded { + // NOTE: for goboring AESGCMTLS we need to lock because of the nonce check + ci.writeLock.Lock() + } c := ci.messageCounter.Add(1) //l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p) @@ -335,8 +349,9 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType var err error out, err = ci.eKey.EncryptDanger(out, out, p, c, nb) - //TODO: see above note on lock - //ci.writeLock.Unlock() + if noiseutil.EncryptLockNeeded { + ci.writeLock.Unlock() + } if err != nil { hostinfo.logger(f.l).WithError(err). WithField("udpAddr", remote).WithField("counter", c). diff --git a/noiseutil/boring.go b/noiseutil/boring.go new file mode 100644 index 0000000..e9ad19b --- /dev/null +++ b/noiseutil/boring.go @@ -0,0 +1,80 @@ +//go:build boringcrypto +// +build boringcrypto + +package noiseutil + +import ( + "crypto/aes" + "crypto/cipher" + "encoding/binary" + + // unsafe needed for go:linkname + _ "unsafe" + + "github.com/flynn/noise" +) + +// EncryptLockNeeded indicates if calls to Encrypt need a lock +// This is true for boringcrypto because the Seal function verifies that the +// nonce is strictly increasing. +const EncryptLockNeeded = true + +// NewGCMTLS is no longer exposed in go1.19+, so we need to link it in +// See: https://github.com/golang/go/issues/56326 +// +// NewGCMTLS is the internal method used with boringcrypto that provices a +// validated mode of AES-GCM which enforces the nonce is strictly +// monotonically increasing. This is the TLS 1.2 specification for nonce +// generation (which also matches the method used by the Noise Protocol) +// +// - https://github.com/golang/go/blob/go1.19/src/crypto/tls/cipher_suites.go#L520-L522 +// - https://github.com/golang/go/blob/go1.19/src/crypto/internal/boring/aes.go#L235-L237 +// - https://github.com/golang/go/blob/go1.19/src/crypto/internal/boring/aes.go#L250 +// - https://github.com/google/boringssl/blob/ae223d6138807a13006342edfeef32e813246b39/include/openssl/aead.h#L379-L381 +// - https://github.com/google/boringssl/blob/ae223d6138807a13006342edfeef32e813246b39/crypto/fipsmodule/cipher/e_aes.c#L1082-L1093 +// +//go:linkname newGCMTLS crypto/internal/boring.NewGCMTLS +func newGCMTLS(c cipher.Block) (cipher.AEAD, error) + +type cipherFn struct { + fn func([32]byte) noise.Cipher + name string +} + +func (c cipherFn) Cipher(k [32]byte) noise.Cipher { return c.fn(k) } +func (c cipherFn) CipherName() string { return c.name } + +// CipherAESGCM is the AES256-GCM AEAD cipher (using NewGCMTLS when GoBoring is present) +var CipherAESGCM noise.CipherFunc = cipherFn{cipherAESGCMBoring, "AESGCM"} + +func cipherAESGCMBoring(k [32]byte) noise.Cipher { + c, err := aes.NewCipher(k[:]) + if err != nil { + panic(err) + } + gcm, err := newGCMTLS(c) + if err != nil { + panic(err) + } + return aeadCipher{ + gcm, + func(n uint64) []byte { + var nonce [12]byte + binary.BigEndian.PutUint64(nonce[4:], n) + return nonce[:] + }, + } +} + +type aeadCipher struct { + cipher.AEAD + nonce func(uint64) []byte +} + +func (c aeadCipher) Encrypt(out []byte, n uint64, ad, plaintext []byte) []byte { + return c.Seal(out, c.nonce(n), plaintext, ad) +} + +func (c aeadCipher) Decrypt(out []byte, n uint64, ad, ciphertext []byte) ([]byte, error) { + return c.Open(out, c.nonce(n), ciphertext, ad) +} diff --git a/noiseutil/boring_test.go b/noiseutil/boring_test.go new file mode 100644 index 0000000..bc5ff50 --- /dev/null +++ b/noiseutil/boring_test.go @@ -0,0 +1,39 @@ +//go:build boringcrypto +// +build boringcrypto + +package noiseutil + +import ( + "encoding/hex" + "testing" + + "github.com/stretchr/testify/assert" +) + +// Ensure NewGCMTLS validates the nonce is non-repeating +func TestNewGCMTLS(t *testing.T) { + // Test Case 16 from GCM Spec: + // - (now dead link): http://csrc.nist.gov/groups/ST/toolkit/BCM/documents/proposedmodes/gcm/gcm-spec.pdf + // - as listed in boringssl tests: https://github.com/google/boringssl/blob/fips-20220613/crypto/cipher_extra/test/cipher_tests.txt#L412-L418 + key, _ := hex.DecodeString("feffe9928665731c6d6a8f9467308308feffe9928665731c6d6a8f9467308308") + iv, _ := hex.DecodeString("cafebabefacedbaddecaf888") + plaintext, _ := hex.DecodeString("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39") + aad, _ := hex.DecodeString("feedfacedeadbeeffeedfacedeadbeefabaddad2") + expected, _ := hex.DecodeString("522dc1f099567d07f47f37a32a84427d643a8cdcbfe5c0c97598a2bd2555d1aa8cb08e48590dbb3da7b08b1056828838c5f61e6393ba7a0abcc9f662") + expectedTag, _ := hex.DecodeString("76fc6ece0f4e1768cddf8853bb2d551b") + + expected = append(expected, expectedTag...) + + var keyArray [32]byte + copy(keyArray[:], key) + c := CipherAESGCM.Cipher(keyArray) + aead := c.(aeadCipher).AEAD + + dst := aead.Seal([]byte{}, iv, plaintext, aad) + assert.Equal(t, expected, dst) + + // We expect this to fail since we are re-encrypting with a repeat IV + assert.PanicsWithError(t, "boringcrypto: EVP_AEAD_CTX_seal failed", func() { + dst = aead.Seal([]byte{}, iv, plaintext, aad) + }) +} diff --git a/noiseutil/notboring.go b/noiseutil/notboring.go new file mode 100644 index 0000000..be746f4 --- /dev/null +++ b/noiseutil/notboring.go @@ -0,0 +1,14 @@ +//go:build !boringcrypto +// +build !boringcrypto + +package noiseutil + +import ( + "github.com/flynn/noise" +) + +// EncryptLockNeeded indicates if calls to Encrypt need a lock +const EncryptLockNeeded = false + +// CipherAESGCM is the standard noise.CipherAESGCM when boringcrypto is not enabled +var CipherAESGCM noise.CipherFunc = noise.CipherAESGCM diff --git a/noiseutil/notboring_test.go b/noiseutil/notboring_test.go new file mode 100644 index 0000000..a27dbbd --- /dev/null +++ b/noiseutil/notboring_test.go @@ -0,0 +1,15 @@ +//go:build !boringcrypto +// +build !boringcrypto + +package noiseutil + +import ( + // NOTE: We have to force these imports here or boring_test.go fails to + // compile correctly. This seems to be a Go bug: + // + // $ GOEXPERIMENT=boringcrypto go test ./noiseutil + // # github.com/slackhq/nebula/noiseutil + // boring_test.go:10:2: cannot find package + + _ "github.com/stretchr/testify/assert" +) From 3cb4e0ef57a588d95a7f553017121a2544e483cd Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Wed, 5 Apr 2023 11:29:26 -0500 Subject: [PATCH 6/8] Allow listen.host to contain names (#825) --- examples/config.yml | 2 +- main.go | 15 ++++++++++++++- udp/udp_generic.go | 4 ++-- udp/udp_linux.go | 4 ++-- udp/udp_tester.go | 4 ++-- 5 files changed, 21 insertions(+), 8 deletions(-) diff --git a/examples/config.yml b/examples/config.yml index 444592f..f8930af 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -107,7 +107,7 @@ lighthouse: # Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined, # however using port 0 will dynamically assign a port and is recommended for roaming nodes. listen: - # To listen on both any ipv4 and ipv6 use "[::]" + # To listen on both any ipv4 and ipv6 use "::" host: 0.0.0.0 port: 4242 # Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg) diff --git a/main.go b/main.go index f9ea77c..bbf831a 100644 --- a/main.go +++ b/main.go @@ -151,8 +151,21 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg port := c.GetInt("listen.port", 0) if !configTest { + rawListenHost := c.GetString("listen.host", "0.0.0.0") + var listenHost *net.IPAddr + if rawListenHost == "[::]" { + // Old guidance was to provide the literal `[::]` in `listen.host` but that won't resolve. + listenHost = &net.IPAddr{IP: net.IPv6zero} + + } else { + listenHost, err = net.ResolveIPAddr("ip", rawListenHost) + if err != nil { + return nil, util.NewContextualError("Failed to resolve listen.host", nil, err) + } + } + for i := 0; i < routines; i++ { - udpServer, err := udp.NewListener(l, c.GetString("listen.host", "0.0.0.0"), port, routines > 1, c.GetInt("listen.batch", 64)) + udpServer, err := udp.NewListener(l, listenHost.IP, port, routines > 1, c.GetInt("listen.batch", 64)) if err != nil { return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) } diff --git a/udp/udp_generic.go b/udp/udp_generic.go index 0a7c0d9..f03174d 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -23,9 +23,9 @@ type Conn struct { l *logrus.Logger } -func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) (*Conn, error) { +func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (*Conn, error) { lc := NewListenConfig(multi) - pc, err := lc.ListenPacket(context.TODO(), "udp", fmt.Sprintf("%s:%d", ip, port)) + pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))) if err != nil { return nil, err } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 5d4b16a..77102ab 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -45,7 +45,7 @@ const ( type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32 -func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) (*Conn, error) { +func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (*Conn, error) { syscall.ForkLock.RLock() fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP) if err == nil { @@ -59,7 +59,7 @@ func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) ( } var lip [16]byte - copy(lip[:], net.ParseIP(ip)) + copy(lip[:], ip.To16()) if multi { if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 62e4f56..3b33f0d 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -45,9 +45,9 @@ type Conn struct { l *logrus.Logger } -func NewListener(l *logrus.Logger, ip string, port int, _ bool, _ int) (*Conn, error) { +func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (*Conn, error) { return &Conn{ - Addr: &Addr{net.ParseIP(ip), uint16(port)}, + Addr: &Addr{ip, uint16(port)}, RxPackets: make(chan *Packet, 10), TxPackets: make(chan *Packet, 10), l: l, From 9b030531910847a1fa448263a1cb1a41330299f9 Mon Sep 17 00:00:00 2001 From: brad-defined <77982333+brad-defined@users.noreply.github.com> Date: Fri, 7 Apr 2023 14:28:37 -0400 Subject: [PATCH 7/8] update EncReader and EncWriter interface function args to have concrete types (#844) * Update LightHouseHandlerFunc to remove EncWriter param. * Move EncWriter to interface * EncReader, too --- handshake.go | 2 +- handshake_ix.go | 25 +++++++++++-------------- handshake_manager.go | 6 +++--- handshake_manager_test.go | 2 +- inside.go | 7 ++----- interface.go | 15 ++++++++++++++- lighthouse.go | 22 ++++++++++++++-------- lighthouse_test.go | 2 +- outside.go | 20 ++++++++++++++++++-- udp/conn.go | 1 - udp/temp.go | 15 +-------------- udp/udp_generic.go | 2 +- udp/udp_linux.go | 2 +- udp/udp_tester.go | 2 +- 14 files changed, 69 insertions(+), 54 deletions(-) diff --git a/handshake.go b/handshake.go index 1cad0db..1f2f03a 100644 --- a/handshake.go +++ b/handshake.go @@ -5,7 +5,7 @@ import ( "github.com/slackhq/nebula/udp" ) -func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via interface{}, packet []byte, h *header.H, hostinfo *HostInfo) { +func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H, hostinfo *HostInfo) { // First remote allow list check before we know the vpnIp if addr != nil { if !f.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) { diff --git a/handshake_ix.go b/handshake_ix.go index a51fb31..b6b5658 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -68,7 +68,7 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) { hostinfo.handshakeStart = time.Now() } -func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []byte, h *header.H) { +func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) { ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0) // Mark packet 1 as seen so it doesn't show up as missed ci.window.Update(f.l, 1) @@ -240,14 +240,13 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b } return } else { - via2 := via.(*ViaSender) - if via2 == nil { + if via == nil { f.l.Error("Handshake send failed: both addr and via are nil.") return } - hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp) - f.SendVia(via2.relayHI, via2.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via2.relayHI.vpnIp). + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) + f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) + f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). Info("Handshake message sent") return @@ -315,14 +314,13 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b Info("Handshake message sent") } } else { - via2 := via.(*ViaSender) - if via2 == nil { + if via == nil { f.l.Error("Handshake send failed: both addr and via are nil.") return } - hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp) - f.SendVia(via2.relayHI, via2.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnIp", vpnIp).WithField("relay", via2.relayHI.vpnIp). + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) + f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) + f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -338,7 +336,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b return } -func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo *HostInfo, packet []byte, h *header.H) bool { +func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *HostInfo, packet []byte, h *header.H) bool { if hostinfo == nil { // Nothing here to tear down, got a bogus stage 2 packet return true @@ -482,8 +480,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo * if addr != nil { hostinfo.SetRemote(addr) } else { - via2 := via.(*ViaSender) - hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp) + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) } // Build up the radix for the firewall if we have subnets in the cert diff --git a/handshake_manager.go b/handshake_manager.go index c8a01ca..ce2811b 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -73,7 +73,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [ } } -func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) { +func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) { clockSource := time.NewTicker(c.config.tryInterval) defer clockSource.Stop() @@ -89,7 +89,7 @@ func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) { } } -func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.EncWriter) { +func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) { c.OutboundHandshakeTimer.Advance(now) for { vpnIp, has := c.OutboundHandshakeTimer.Purge() @@ -100,7 +100,7 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.E } } -func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, lighthouseTriggered bool) { +func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, lighthouseTriggered bool) { hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp) if err != nil { return diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 84b8ef6..3be8a1b 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -84,7 +84,7 @@ func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess return } -func (mw *mockEncWriter) SendVia(via interface{}, relay interface{}, ad, nb, out []byte, nocopy bool) { +func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { return } diff --git a/inside.go b/inside.go index 457fcac..9c40251 100644 --- a/inside.go +++ b/inside.go @@ -248,16 +248,13 @@ func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *C // nb is a buffer used to store the nonce value, re-used for performance reasons. // out is a buffer used to store the result of the Encrypt operation // q indicates which writer to use to send the packet. -func (f *Interface) SendVia(viaIfc interface{}, - relayIfc interface{}, +func (f *Interface) SendVia(via *HostInfo, + relay *Relay, ad, nb, out []byte, nocopy bool, ) { - via := viaIfc.(*HostInfo) - relay := relayIfc.(*Relay) - if noiseutil.EncryptLockNeeded { // NOTE: for goboring AESGCMTLS we need to lock because of the nonce check via.ConnectionState.writeLock.Lock() diff --git a/interface.go b/interface.go index af83abc..e87f9f9 100644 --- a/interface.go +++ b/interface.go @@ -16,6 +16,7 @@ import ( "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" + "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" @@ -89,6 +90,18 @@ type Interface struct { l *logrus.Logger } +type EncWriter interface { + SendVia(via *HostInfo, + relay *Relay, + ad, + nb, + out []byte, + nocopy bool, + ) + SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) + Handshake(vpnIp iputil.VpnIp) +} + type sendRecvErrorConfig uint8 const ( @@ -238,7 +251,7 @@ func (f *Interface) listenOut(i int) { lhh := f.lightHouse.NewRequestHandler() conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) - li.ListenOut(f.readOutsidePackets, lhh.HandleRequest, conntrackCache, i) + li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i) } func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { diff --git a/lighthouse.go b/lighthouse.go index 5b34a3e..d6b6a5f 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -65,7 +65,7 @@ type LightHouse struct { interval atomic.Int64 updateCancel context.CancelFunc updateParentCtx context.Context - updateUdp udp.EncWriter + updateUdp EncWriter nebulaPort uint32 // 32 bits because protobuf does not have a uint16 advertiseAddrs atomic.Pointer[[]netIpAndPort] @@ -382,7 +382,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList return nil } -func (lh *LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList { +func (lh *LightHouse) Query(ip iputil.VpnIp, f EncWriter) *RemoteList { if !lh.IsLighthouseIP(ip) { lh.QueryServer(ip, f) } @@ -396,7 +396,7 @@ func (lh *LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList { } // This is asynchronous so no reply should be expected -func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f udp.EncWriter) { +func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f EncWriter) { if lh.amLighthouse { return } @@ -629,7 +629,7 @@ func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr { return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port)) } -func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) { +func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) { lh.updateParentCtx = ctx lh.updateUdp = f @@ -655,7 +655,7 @@ func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) { } } -func (lh *LightHouse) SendUpdate(f udp.EncWriter) { +func (lh *LightHouse) SendUpdate(f EncWriter) { var v4 []*Ip4AndPort var v6 []*Ip6AndPort @@ -760,7 +760,13 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta { return lhh.meta } -func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w udp.EncWriter) { +func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc { + return func(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte) { + lhh.HandleRequest(rAddr, vpnIp, p, f) + } +} + +func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) { n := lhh.resetMeta() err := n.Unmarshal(p) if err != nil { @@ -795,7 +801,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, } } -func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w udp.EncWriter) { +func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w EncWriter) { // Exit if we don't answer queries if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { @@ -928,7 +934,7 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp am.Unlock() } -func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w udp.EncWriter) { +func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) { if !lhh.lh.IsLighthouseIP(vpnIp) { return } diff --git a/lighthouse_test.go b/lighthouse_test.go index e5a1692..1824463 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -372,7 +372,7 @@ type testEncWriter struct { metaFilter *NebulaMeta_MessageType } -func (tw *testEncWriter) SendVia(via interface{}, relay interface{}, ad, nb, out []byte, nocopy bool) { +func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { } func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) { } diff --git a/outside.go b/outside.go index fd6f0a3..8361ce3 100644 --- a/outside.go +++ b/outside.go @@ -21,7 +21,23 @@ const ( minFwPacketLen = 4 ) -func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) { +func readOutsidePackets(f *Interface) udp.EncReader { + return func( + addr *udp.Addr, + out []byte, + packet []byte, + header *header.H, + fwPacket *firewall.Packet, + lhh udp.LightHouseHandlerFunc, + nb []byte, + q int, + localCache firewall.ConntrackCache, + ) { + f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache) + } +} + +func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) { err := h.Parse(packet) if err != nil { // TODO: best if we return this and let caller log @@ -149,7 +165,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []by return } - lhf(addr, hostinfo.vpnIp, d, f) + lhf(addr, hostinfo.vpnIp, d) // Fallthrough to the bottom to record incoming traffic diff --git a/udp/conn.go b/udp/conn.go index fa52fe5..f967a9a 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -9,7 +9,6 @@ const MTU = 9001 type EncReader func( addr *Addr, - via interface{}, out []byte, packet []byte, header *header.H, diff --git a/udp/temp.go b/udp/temp.go index 5cc8c1c..2efe31d 100644 --- a/udp/temp.go +++ b/udp/temp.go @@ -1,22 +1,9 @@ package udp import ( - "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" ) -type EncWriter interface { - SendVia(via interface{}, - relay interface{}, - ad, - nb, - out []byte, - nocopy bool, - ) - SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) - Handshake(vpnIp iputil.VpnIp) -} - //TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare -type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) +type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte) diff --git a/udp/udp_generic.go b/udp/udp_generic.go index f03174d..ff254eb 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -86,6 +86,6 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall udpAddr.IP = rua.IP udpAddr.Port = uint16(rua.Port) - r(udpAddr, nil, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l)) } } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 77102ab..26bbe36 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -145,7 +145,7 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall for i := 0; i < n; i++ { udpAddr.IP = names[i][8:24] udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4]) - r(udpAddr, nil, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l)) } } } diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 3b33f0d..8b5e531 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -122,7 +122,7 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall } ua.Port = p.FromPort copy(ua.IP, p.FromIp.To16()) - r(ua, nil, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l)) } } From 397fe5f8797e3386db3db99376b716da2566e7b8 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Mon, 10 Apr 2023 12:32:37 -0500 Subject: [PATCH 8/8] Add ability to skip installing unsafe routes on the os routing table (#831) --- examples/config.yml | 9 ++++++--- overlay/route.go | 28 ++++++++++++++++++++-------- overlay/route_test.go | 22 +++++++++++++++++----- overlay/tun_darwin.go | 2 +- overlay/tun_freebsd.go | 2 +- overlay/tun_linux.go | 4 ++++ overlay/tun_water_windows.go | 2 +- overlay/tun_wintun_windows.go | 2 +- 8 files changed, 51 insertions(+), 20 deletions(-) diff --git a/examples/config.yml b/examples/config.yml index f8930af..db5d0e3 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -204,21 +204,24 @@ tun: tx_queue: 500 # Default MTU for every packet, safe setting is (and the default) 1300 for internet based traffic mtu: 1300 + # Route based MTU overrides, you have known vpn ip paths that can support larger MTUs you can increase/decrease them here routes: #- mtu: 8800 # route: 10.0.0.0/16 + # Unsafe routes allows you to route traffic over nebula to non-nebula nodes # Unsafe routes should be avoided unless you have hosts/services that cannot run nebula # NOTE: The nebula certificate of the "via" node *MUST* have the "route" defined as a subnet in its certificate - # `mtu` will default to tun mtu if this option is not specified - # `metric` will default to 0 if this option is not specified + # `mtu`: will default to tun mtu if this option is not specified + # `metric`: will default to 0 if this option is not specified + # `install`: will default to true, controls whether this route is installed in the systems routing table. unsafe_routes: #- route: 172.16.1.0/24 # via: 192.168.100.99 # mtu: 1300 # metric: 100 - + # install: true # TODO # Configure logging level diff --git a/overlay/route.go b/overlay/route.go index e8626bb..41c7a9c 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -14,10 +14,11 @@ import ( ) type Route struct { - MTU int - Metric int - Cidr *net.IPNet - Via *iputil.VpnIp + MTU int + Metric int + Cidr *net.IPNet + Via *iputil.VpnIp + Install bool } func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4, error) { @@ -81,7 +82,8 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) { } r := Route{ - MTU: mtu, + Install: true, + MTU: mtu, } _, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute)) @@ -182,10 +184,20 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { viaVpnIp := iputil.Ip2VpnIp(nVia) + install := true + rInstall, ok := m["install"] + if ok { + install, err = strconv.ParseBool(fmt.Sprintf("%v", rInstall)) + if err != nil { + return nil, fmt.Errorf("entry %v.install in tun.unsafe_routes is not a boolean: %v", i+1, err) + } + } + r := Route{ - Via: &viaVpnIp, - MTU: mtu, - Metric: metric, + Via: &viaVpnIp, + MTU: mtu, + Metric: metric, + Install: install, } _, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute)) diff --git a/overlay/route_test.go b/overlay/route_test.go index 1d4286d..f83b5c1 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -92,6 +92,8 @@ func Test_parseRoutes(t *testing.T) { tested := 0 for _, r := range routes { + assert.True(t, r.Install) + if r.MTU == 8000 { assert.Equal(t, "10.0.0.1/32", r.Cidr.String()) tested++ @@ -205,35 +207,45 @@ func Test_parseUnsafeRoutes(t *testing.T) { assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499") + // bad install + c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}} + routes, err = parseUnsafeRoutes(c, n) + assert.Nil(t, routes) + assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax") + // happy case c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{ - map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29"}, - map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32"}, + map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "t"}, + map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32", "install": 0}, + map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1}, map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"}, }} routes, err = parseUnsafeRoutes(c, n) assert.Nil(t, err) - assert.Len(t, routes, 3) + assert.Len(t, routes, 4) tested := 0 for _, r := range routes { if r.MTU == 8000 { assert.Equal(t, "1.0.0.1/32", r.Cidr.String()) + assert.False(t, r.Install) tested++ } else if r.MTU == 9000 { assert.Equal(t, 9000, r.MTU) assert.Equal(t, "1.0.0.0/29", r.Cidr.String()) + assert.True(t, r.Install) tested++ } else { assert.Equal(t, 1500, r.MTU) assert.Equal(t, 1234, r.Metric) assert.Equal(t, "1.0.0.2/32", r.Cidr.String()) + assert.True(t, r.Install) tested++ } } - if tested != 3 { - t.Fatal("Did not see both unsafe_routes") + if tested != 4 { + t.Fatal("Did not see all unsafe_routes") } } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index d7b4884..6320570 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -287,7 +287,7 @@ func (t *tun) Activate() error { // Unsafe path routes for _, r := range t.Routes { - if r.Via == nil { + if r.Via == nil || !r.Install { // We don't allow route MTUs so only install routes with a via continue } diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 0a3f722..1054228 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -86,7 +86,7 @@ func (t *tun) Activate() error { } // Unsafe path routes for _, r := range t.Routes { - if r.Via == nil { + if r.Via == nil || !r.Install { // We don't allow route MTUs so only install routes with a via continue } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 1406438..932b585 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -279,6 +279,10 @@ func (t tun) Activate() error { // Path routes for _, r := range t.Routes { + if !r.Install { + continue + } + nr := netlink.Route{ LinkIndex: link.Attrs().Index, Dst: r.Cidr, diff --git a/overlay/tun_water_windows.go b/overlay/tun_water_windows.go index 8e2e571..b1c28d6 100644 --- a/overlay/tun_water_windows.go +++ b/overlay/tun_water_windows.go @@ -80,7 +80,7 @@ func (t *waterTun) Activate() error { } for _, r := range t.Routes { - if r.Via == nil { + if r.Via == nil || !r.Install { // We don't allow route MTUs so only install routes with a via continue } diff --git a/overlay/tun_wintun_windows.go b/overlay/tun_wintun_windows.go index 0538849..9146c88 100644 --- a/overlay/tun_wintun_windows.go +++ b/overlay/tun_wintun_windows.go @@ -92,7 +92,7 @@ func (t *winTun) Activate() error { routes := make([]*winipcfg.RouteData, 0, len(t.Routes)+1) for _, r := range t.Routes { - if r.Via == nil { + if r.Via == nil || !r.Install { // We don't allow route MTUs so only install routes with a via continue }