diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f0ebbb..e1c4c00 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added +- `nebula-cert ca` now supports encrypting the CA's private key with a + passphrase. Pass `-encrypt` in order to be prompted for a passphrase. + Encryption is performed using AES-256-GCM and Argon2id for KDF. KDF + parameters default to RFC recommendations, but can be overridden via CLI + flags `-argon-memory`, `-argon-parallelism`, and `-argon-iterations`. + ## [1.6.1] - 2022-09-26 ### Fixed diff --git a/cert/cert.go b/cert/cert.go index f3df89c..216efcf 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -9,7 +9,9 @@ import ( "encoding/hex" "encoding/json" "encoding/pem" + "errors" "fmt" + "math" "net" "time" @@ -21,11 +23,12 @@ import ( const publicKeyLen = 32 const ( - CertBanner = "NEBULA CERTIFICATE" - X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY" - X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY" - Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY" - Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY" + CertBanner = "NEBULA CERTIFICATE" + X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY" + X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY" + EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY" + Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY" + Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY" ) type NebulaCertificate struct { @@ -48,8 +51,21 @@ type NebulaCertificateDetails struct { InvertedGroups map[string]struct{} } +type NebulaEncryptedData struct { + EncryptionMetadata NebulaEncryptionMetadata + Ciphertext []byte +} + +type NebulaEncryptionMetadata struct { + EncryptionAlgorithm string + Argon2Parameters Argon2Parameters +} + type m map[string]interface{} +// Returned if we try to unmarshal an encrypted private key without a passphrase +var ErrPrivateKeyEncrypted = errors.New("private key must be decrypted") + // UnmarshalNebulaCertificate will unmarshal a protobuf byte representation of a nebula cert func UnmarshalNebulaCertificate(b []byte) (*NebulaCertificate, error) { if len(b) == 0 { @@ -144,6 +160,30 @@ func MarshalEd25519PrivateKey(key ed25519.PrivateKey) []byte { return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: key}) } +// EncryptAndMarshalX25519PrivateKey is a simple helper to encrypt and PEM encode an X25519 private key +func EncryptAndMarshalEd25519PrivateKey(b []byte, passphrase []byte, kdfParams *Argon2Parameters) ([]byte, error) { + ciphertext, err := aes256Encrypt(passphrase, kdfParams, b) + if err != nil { + return nil, err + } + + b, err = proto.Marshal(&RawNebulaEncryptedData{ + EncryptionMetadata: &RawNebulaEncryptionMetadata{ + EncryptionAlgorithm: "AES-256-GCM", + Argon2Parameters: &RawNebulaArgon2Parameters{ + Version: kdfParams.version, + Memory: kdfParams.Memory, + Parallelism: uint32(kdfParams.Parallelism), + Iterations: kdfParams.Iterations, + Salt: kdfParams.salt, + }, + }, + Ciphertext: ciphertext, + }) + + return pem.EncodeToMemory(&pem.Block{Type: EncryptedEd25519PrivateKeyBanner, Bytes: b}), nil +} + // UnmarshalX25519PrivateKey will try to pem decode an X25519 private key, returning any other bytes b // or an error on failure func UnmarshalX25519PrivateKey(b []byte) ([]byte, []byte, error) { @@ -168,9 +208,13 @@ func UnmarshalEd25519PrivateKey(b []byte) (ed25519.PrivateKey, []byte, error) { if k == nil { return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") } - if k.Type != Ed25519PrivateKeyBanner { + + if k.Type == EncryptedEd25519PrivateKeyBanner { + return nil, r, ErrPrivateKeyEncrypted + } else if k.Type != Ed25519PrivateKeyBanner { return nil, r, fmt.Errorf("bytes did not contain a proper nebula Ed25519 private key banner") } + if len(k.Bytes) != ed25519.PrivateKeySize { return nil, r, fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") } @@ -178,6 +222,101 @@ func UnmarshalEd25519PrivateKey(b []byte) (ed25519.PrivateKey, []byte, error) { return k.Bytes, r, nil } +// UnmarshalNebulaCertificate will unmarshal a protobuf byte representation of a nebula cert into its +// protobuf-generated struct. +func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) { + if len(b) == 0 { + return nil, fmt.Errorf("nil byte array") + } + var rned RawNebulaEncryptedData + err := proto.Unmarshal(b, &rned) + if err != nil { + return nil, err + } + + if rned.EncryptionMetadata == nil { + return nil, fmt.Errorf("encoded EncryptionMetadata was nil") + } + + if rned.EncryptionMetadata.Argon2Parameters == nil { + return nil, fmt.Errorf("encoded Argon2Parameters was nil") + } + + params, err := unmarshalArgon2Parameters(rned.EncryptionMetadata.Argon2Parameters) + if err != nil { + return nil, err + } + + ned := NebulaEncryptedData{ + EncryptionMetadata: NebulaEncryptionMetadata{ + EncryptionAlgorithm: rned.EncryptionMetadata.EncryptionAlgorithm, + Argon2Parameters: *params, + }, + Ciphertext: rned.Ciphertext, + } + + return &ned, nil +} + +func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) { + if params.Version < math.MinInt32 || params.Version > math.MaxInt32 { + return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32) + } + if params.Memory <= 0 || params.Memory > math.MaxUint32 { + return nil, fmt.Errorf("Argon2Parameters Memory must be be greater than 0 and no more than %d KiB", uint32(math.MaxUint32)) + } + if params.Parallelism <= 0 || params.Parallelism > math.MaxUint8 { + return nil, fmt.Errorf("Argon2Parameters Parallelism must be be greater than 0 and no more than %d", math.MaxUint8) + } + if params.Iterations <= 0 || params.Iterations > math.MaxUint32 { + return nil, fmt.Errorf("-argon-iterations must be be greater than 0 and no more than %d", uint32(math.MaxUint32)) + } + + return &Argon2Parameters{ + version: rune(params.Version), + Memory: uint32(params.Memory), + Parallelism: uint8(params.Parallelism), + Iterations: uint32(params.Iterations), + salt: params.Salt, + }, nil + +} + +// DecryptAndUnmarshalEd25519PrivateKey will try to pem decode and decrypt an Ed25519 private key with +// the given passphrase, returning any other bytes b or an error on failure +func DecryptAndUnmarshalEd25519PrivateKey(passphrase, b []byte) (ed25519.PrivateKey, []byte, error) { + k, r := pem.Decode(b) + if k == nil { + return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") + } + + if k.Type != EncryptedEd25519PrivateKeyBanner { + return nil, r, fmt.Errorf("bytes did not contain a proper nebula encrypted Ed25519 private key banner") + } + + ned, err := UnmarshalNebulaEncryptedData(k.Bytes) + if err != nil { + return nil, r, err + } + + var bytes []byte + switch ned.EncryptionMetadata.EncryptionAlgorithm { + case "AES-256-GCM": + bytes, err = aes256Decrypt(passphrase, &ned.EncryptionMetadata.Argon2Parameters, ned.Ciphertext) + if err != nil { + return nil, r, err + } + default: + return nil, r, fmt.Errorf("unsupported encryption algorithm: %s", ned.EncryptionMetadata.EncryptionAlgorithm) + } + + if len(bytes) != ed25519.PrivateKeySize { + return nil, r, fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") + } + + return bytes, r, nil +} + // MarshalX25519PublicKey is a simple helper to PEM encode an X25519 public key func MarshalX25519PublicKey(b []byte) []byte { return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b}) diff --git a/cert/cert.pb.go b/cert/cert.pb.go index 094aefb..1aa1b4b 100644 --- a/cert/cert.pb.go +++ b/cert/cert.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.28.0 -// protoc v3.20.0 +// protoc v3.19.4 // source: cert.proto package cert @@ -188,6 +188,195 @@ func (x *RawNebulaCertificateDetails) GetIssuer() []byte { return nil } +type RawNebulaEncryptedData struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + EncryptionMetadata *RawNebulaEncryptionMetadata `protobuf:"bytes,1,opt,name=EncryptionMetadata,proto3" json:"EncryptionMetadata,omitempty"` + Ciphertext []byte `protobuf:"bytes,2,opt,name=Ciphertext,proto3" json:"Ciphertext,omitempty"` +} + +func (x *RawNebulaEncryptedData) Reset() { + *x = RawNebulaEncryptedData{} + if protoimpl.UnsafeEnabled { + mi := &file_cert_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RawNebulaEncryptedData) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RawNebulaEncryptedData) ProtoMessage() {} + +func (x *RawNebulaEncryptedData) ProtoReflect() protoreflect.Message { + mi := &file_cert_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RawNebulaEncryptedData.ProtoReflect.Descriptor instead. +func (*RawNebulaEncryptedData) Descriptor() ([]byte, []int) { + return file_cert_proto_rawDescGZIP(), []int{2} +} + +func (x *RawNebulaEncryptedData) GetEncryptionMetadata() *RawNebulaEncryptionMetadata { + if x != nil { + return x.EncryptionMetadata + } + return nil +} + +func (x *RawNebulaEncryptedData) GetCiphertext() []byte { + if x != nil { + return x.Ciphertext + } + return nil +} + +type RawNebulaEncryptionMetadata struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + EncryptionAlgorithm string `protobuf:"bytes,1,opt,name=EncryptionAlgorithm,proto3" json:"EncryptionAlgorithm,omitempty"` + Argon2Parameters *RawNebulaArgon2Parameters `protobuf:"bytes,2,opt,name=Argon2Parameters,proto3" json:"Argon2Parameters,omitempty"` +} + +func (x *RawNebulaEncryptionMetadata) Reset() { + *x = RawNebulaEncryptionMetadata{} + if protoimpl.UnsafeEnabled { + mi := &file_cert_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RawNebulaEncryptionMetadata) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RawNebulaEncryptionMetadata) ProtoMessage() {} + +func (x *RawNebulaEncryptionMetadata) ProtoReflect() protoreflect.Message { + mi := &file_cert_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RawNebulaEncryptionMetadata.ProtoReflect.Descriptor instead. +func (*RawNebulaEncryptionMetadata) Descriptor() ([]byte, []int) { + return file_cert_proto_rawDescGZIP(), []int{3} +} + +func (x *RawNebulaEncryptionMetadata) GetEncryptionAlgorithm() string { + if x != nil { + return x.EncryptionAlgorithm + } + return "" +} + +func (x *RawNebulaEncryptionMetadata) GetArgon2Parameters() *RawNebulaArgon2Parameters { + if x != nil { + return x.Argon2Parameters + } + return nil +} + +type RawNebulaArgon2Parameters struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Version int32 `protobuf:"varint,1,opt,name=version,proto3" json:"version,omitempty"` // rune in Go + Memory uint32 `protobuf:"varint,2,opt,name=memory,proto3" json:"memory,omitempty"` + Parallelism uint32 `protobuf:"varint,4,opt,name=parallelism,proto3" json:"parallelism,omitempty"` // uint8 in Go + Iterations uint32 `protobuf:"varint,3,opt,name=iterations,proto3" json:"iterations,omitempty"` + Salt []byte `protobuf:"bytes,5,opt,name=salt,proto3" json:"salt,omitempty"` +} + +func (x *RawNebulaArgon2Parameters) Reset() { + *x = RawNebulaArgon2Parameters{} + if protoimpl.UnsafeEnabled { + mi := &file_cert_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RawNebulaArgon2Parameters) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RawNebulaArgon2Parameters) ProtoMessage() {} + +func (x *RawNebulaArgon2Parameters) ProtoReflect() protoreflect.Message { + mi := &file_cert_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RawNebulaArgon2Parameters.ProtoReflect.Descriptor instead. +func (*RawNebulaArgon2Parameters) Descriptor() ([]byte, []int) { + return file_cert_proto_rawDescGZIP(), []int{4} +} + +func (x *RawNebulaArgon2Parameters) GetVersion() int32 { + if x != nil { + return x.Version + } + return 0 +} + +func (x *RawNebulaArgon2Parameters) GetMemory() uint32 { + if x != nil { + return x.Memory + } + return 0 +} + +func (x *RawNebulaArgon2Parameters) GetParallelism() uint32 { + if x != nil { + return x.Parallelism + } + return 0 +} + +func (x *RawNebulaArgon2Parameters) GetIterations() uint32 { + if x != nil { + return x.Iterations + } + return 0 +} + +func (x *RawNebulaArgon2Parameters) GetSalt() []byte { + if x != nil { + return x.Salt + } + return nil +} + var File_cert_proto protoreflect.FileDescriptor var file_cert_proto_rawDesc = []byte{ @@ -215,9 +404,38 @@ var file_cert_proto_rawDesc = []byte{ 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x49, 0x73, 0x43, 0x41, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x49, 0x73, 0x43, 0x41, 0x12, 0x16, 0x0a, 0x06, 0x49, 0x73, 0x73, 0x75, 0x65, 0x72, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x49, 0x73, 0x73, 0x75, 0x65, - 0x72, 0x42, 0x20, 0x5a, 0x1e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, - 0x73, 0x6c, 0x61, 0x63, 0x6b, 0x68, 0x71, 0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x2f, 0x63, - 0x65, 0x72, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x72, 0x22, 0x8b, 0x01, 0x0a, 0x16, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, + 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x44, 0x61, 0x74, 0x61, 0x12, 0x51, 0x0a, 0x12, + 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, + 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, + 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, 0x12, 0x45, 0x6e, 0x63, + 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, + 0x1e, 0x0a, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74, 0x22, + 0x9c, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, + 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, + 0x30, 0x0a, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, + 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x45, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, + 0x6d, 0x12, 0x4b, 0x0a, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, + 0x65, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x63, 0x65, + 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f, + 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x52, 0x10, 0x41, 0x72, + 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x22, 0xa3, + 0x01, 0x0a, 0x19, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f, + 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, + 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, + 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x6d, 0x6f, 0x72, 0x79, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x6d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x12, 0x20, + 0x0a, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d, + 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, + 0x12, 0x12, 0x0a, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, + 0x73, 0x61, 0x6c, 0x74, 0x42, 0x20, 0x5a, 0x1e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, + 0x6f, 0x6d, 0x2f, 0x73, 0x6c, 0x61, 0x63, 0x6b, 0x68, 0x71, 0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, + 0x61, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -232,18 +450,23 @@ func file_cert_proto_rawDescGZIP() []byte { return file_cert_proto_rawDescData } -var file_cert_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_cert_proto_msgTypes = make([]protoimpl.MessageInfo, 5) var file_cert_proto_goTypes = []interface{}{ (*RawNebulaCertificate)(nil), // 0: cert.RawNebulaCertificate (*RawNebulaCertificateDetails)(nil), // 1: cert.RawNebulaCertificateDetails + (*RawNebulaEncryptedData)(nil), // 2: cert.RawNebulaEncryptedData + (*RawNebulaEncryptionMetadata)(nil), // 3: cert.RawNebulaEncryptionMetadata + (*RawNebulaArgon2Parameters)(nil), // 4: cert.RawNebulaArgon2Parameters } var file_cert_proto_depIdxs = []int32{ 1, // 0: cert.RawNebulaCertificate.Details:type_name -> cert.RawNebulaCertificateDetails - 1, // [1:1] is the sub-list for method output_type - 1, // [1:1] is the sub-list for method input_type - 1, // [1:1] is the sub-list for extension type_name - 1, // [1:1] is the sub-list for extension extendee - 0, // [0:1] is the sub-list for field type_name + 3, // 1: cert.RawNebulaEncryptedData.EncryptionMetadata:type_name -> cert.RawNebulaEncryptionMetadata + 4, // 2: cert.RawNebulaEncryptionMetadata.Argon2Parameters:type_name -> cert.RawNebulaArgon2Parameters + 3, // [3:3] is the sub-list for method output_type + 3, // [3:3] is the sub-list for method input_type + 3, // [3:3] is the sub-list for extension type_name + 3, // [3:3] is the sub-list for extension extendee + 0, // [0:3] is the sub-list for field type_name } func init() { file_cert_proto_init() } @@ -276,6 +499,42 @@ func file_cert_proto_init() { return nil } } + file_cert_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RawNebulaEncryptedData); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_cert_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RawNebulaEncryptionMetadata); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_cert_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RawNebulaArgon2Parameters); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ @@ -283,7 +542,7 @@ func file_cert_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_cert_proto_rawDesc, NumEnums: 0, - NumMessages: 2, + NumMessages: 5, NumExtensions: 0, NumServices: 0, }, diff --git a/cert/cert.proto b/cert/cert.proto index e135dd1..be7b132 100644 --- a/cert/cert.proto +++ b/cert/cert.proto @@ -26,4 +26,22 @@ message RawNebulaCertificateDetails { // sha-256 of the issuer certificate, if this field is blank the cert is self-signed bytes Issuer = 9; -} \ No newline at end of file +} + +message RawNebulaEncryptedData { + RawNebulaEncryptionMetadata EncryptionMetadata = 1; + bytes Ciphertext = 2; +} + +message RawNebulaEncryptionMetadata { + string EncryptionAlgorithm = 1; + RawNebulaArgon2Parameters Argon2Parameters = 2; +} + +message RawNebulaArgon2Parameters { + int32 version = 1; // rune in Go + uint32 memory = 2; + uint32 parallelism = 4; // uint8 in Go + uint32 iterations = 3; + bytes salt = 5; +} diff --git a/cert/cert_test.go b/cert/cert_test.go index 5a82741..ece9f7f 100644 --- a/cert/cert_test.go +++ b/cert/cert_test.go @@ -578,6 +578,91 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA assert.EqualError(t, err, "input did not contain a valid PEM encoded block") } +func TestDecryptAndUnmarshalEd25519PrivateKey(t *testing.T) { + passphrase := []byte("DO NOT USE THIS KEY") + privKey := []byte(`# A good key +-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT +oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl ++Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB +qrlJ69wer3ZUHFXA +-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +`) + shortKey := []byte(`# A key which, once decrypted, is too short +-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCoga5h8owMEBWRSMMJKzuUvWce7 +k0qlBkQmCxiuLh80MuASW70YcKt8jeEIS2axo2V6zAKA9TSMcCsJW1kDDXEtL/xe +GLF5T7sDl5COp4LU3pGxpV+KoeQ/S3gQCAAcnaOtnJQX+aSDnbO3jCHyP7U9CHbs +rQr3bdH3Oy/WiYU= +-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +`) + invalidBanner := []byte(`# Invalid banner (not encrypted) +-----BEGIN NEBULA ED25519 PRIVATE KEY----- +bWRp2CTVFhW9HD/qCd28ltDgK3w8VXSeaEYczDWos8sMUBqDb9jP3+NYwcS4lURG +XgLvodMXZJuaFPssp+WwtA== +-----END NEBULA ED25519 PRIVATE KEY----- +`) + invalidPem := []byte(`# Not a valid PEM format +-BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT +oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl ++Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB +qrlJ69wer3ZUHFXA +-END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +`) + + keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem) + + // Success test case + k, rest, err := DecryptAndUnmarshalEd25519PrivateKey(passphrase, keyBundle) + assert.Nil(t, err) + assert.Len(t, k, 64) + assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) + + // Fail due to short key + k, rest, err = DecryptAndUnmarshalEd25519PrivateKey(passphrase, rest) + assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key") + assert.Nil(t, k) + assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) + + // Fail due to invalid banner + k, rest, err = DecryptAndUnmarshalEd25519PrivateKey(passphrase, rest) + assert.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519 private key banner") + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + + // Fail due to ivalid PEM format, because + // it's missing the requisite pre-encapsulation boundary. + k, rest, err = DecryptAndUnmarshalEd25519PrivateKey(passphrase, rest) + assert.EqualError(t, err, "input did not contain a valid PEM encoded block") + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + + // Fail due to invalid passphrase + k, rest, err = DecryptAndUnmarshalEd25519PrivateKey([]byte("invalid passphrase"), privKey) + assert.EqualError(t, err, "invalid passphrase or corrupt private key") + assert.Nil(t, k) + assert.Equal(t, rest, []byte{}) +} + +func TestEncryptAndMarshalEd25519PrivateKey(t *testing.T) { + // Having proved that decryption works correctly above, we can test the + // encryption function produces a value which can be decrypted + passphrase := []byte("passphrase") + bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") + kdfParams := NewArgon2Parameters(64*1024, 4, 3) + key, err := EncryptAndMarshalEd25519PrivateKey(bytes, passphrase, kdfParams) + assert.Nil(t, err) + + // Verify the "key" can be decrypted successfully + k, rest, err := DecryptAndUnmarshalEd25519PrivateKey(passphrase, key) + assert.Len(t, k, 64) + assert.Equal(t, rest, []byte{}) + assert.Nil(t, err) + + // EncryptAndMarshalEd25519PrivateKey does not create any errors itself +} + func TestUnmarshalX25519PrivateKey(t *testing.T) { privKey := []byte(`# A good key -----BEGIN NEBULA X25519 PRIVATE KEY----- diff --git a/cert/crypto.go b/cert/crypto.go new file mode 100644 index 0000000..94f4c48 --- /dev/null +++ b/cert/crypto.go @@ -0,0 +1,140 @@ +package cert + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "fmt" + "io" + + "golang.org/x/crypto/argon2" +) + +// KDF factors +type Argon2Parameters struct { + version rune + Memory uint32 // KiB + Parallelism uint8 + Iterations uint32 + salt []byte +} + +// Returns a new Argon2Parameters object with current version set +func NewArgon2Parameters(memory uint32, parallelism uint8, iterations uint32) *Argon2Parameters { + return &Argon2Parameters{ + version: argon2.Version, + Memory: memory, // KiB + Parallelism: parallelism, + Iterations: iterations, + } +} + +// Encrypts data using AES-256-GCM and the Argon2id key derivation function +func aes256Encrypt(passphrase []byte, kdfParams *Argon2Parameters, data []byte) ([]byte, error) { + key, err := aes256DeriveKey(passphrase, kdfParams) + if err != nil { + return nil, err + } + + // this should never happen, but since this dictates how our calls into the + // aes package behave and could be catastraphic, let's sanity check this + if len(key) != 32 { + return nil, fmt.Errorf("invalid AES-256 key length (%d) - cowardly refusing to encrypt", len(key)) + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + nonce := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + + ciphertext := gcm.Seal(nil, nonce, data, nil) + blob := joinNonceCiphertext(nonce, ciphertext) + + return blob, nil +} + +// Decrypts data using AES-256-GCM and the Argon2id key derivation function +// Expects the data to include an Argon2id parameter string before the encrypted data +func aes256Decrypt(passphrase []byte, kdfParams *Argon2Parameters, data []byte) ([]byte, error) { + key, err := aes256DeriveKey(passphrase, kdfParams) + if err != nil { + return nil, err + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(block) + + nonce, ciphertext, err := splitNonceCiphertext(data, gcm.NonceSize()) + if err != nil { + return nil, err + } + + plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return nil, fmt.Errorf("invalid passphrase or corrupt private key") + } + + return plaintext, nil +} + +func aes256DeriveKey(passphrase []byte, params *Argon2Parameters) ([]byte, error) { + if params.salt == nil { + params.salt = make([]byte, 32) + if _, err := rand.Read(params.salt); err != nil { + return nil, err + } + } + + // keySize of 32 bytes will result in AES-256 encryption + key, err := deriveKey(passphrase, 32, params) + if err != nil { + return nil, err + } + + return key, nil +} + +// Derives a key from a passphrase using Argon2id +func deriveKey(passphrase []byte, keySize uint32, params *Argon2Parameters) ([]byte, error) { + if params.version != argon2.Version { + return nil, fmt.Errorf("incompatible Argon2 version: %d", params.version) + } + + if params.salt == nil { + return nil, fmt.Errorf("salt must be set in argon2Parameters") + } else if len(params.salt) < 16 { + return nil, fmt.Errorf("salt must be at least 128 bits") + } + + key := argon2.IDKey(passphrase, params.salt, params.Iterations, params.Memory, params.Parallelism, keySize) + + return key, nil +} + +// Prepends nonce to ciphertext +func joinNonceCiphertext(nonce []byte, ciphertext []byte) []byte { + return append(nonce, ciphertext...) +} + +// Splits nonce from ciphertext +func splitNonceCiphertext(blob []byte, nonceSize int) ([]byte, []byte, error) { + if len(blob) <= nonceSize { + return nil, nil, fmt.Errorf("invalid ciphertext blob - blob shorter than nonce length") + } + + return blob[:nonceSize], blob[nonceSize:], nil +} diff --git a/cert/crypto_test.go b/cert/crypto_test.go new file mode 100644 index 0000000..c2e61df --- /dev/null +++ b/cert/crypto_test.go @@ -0,0 +1,25 @@ +package cert + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/argon2" +) + +func TestNewArgon2Parameters(t *testing.T) { + p := NewArgon2Parameters(64*1024, 4, 3) + assert.EqualValues(t, &Argon2Parameters{ + version: argon2.Version, + Memory: 64 * 1024, + Parallelism: 4, + Iterations: 3, + }, p) + p = NewArgon2Parameters(2*1024*1024, 2, 1) + assert.EqualValues(t, &Argon2Parameters{ + version: argon2.Version, + Memory: 2 * 1024 * 1024, + Parallelism: 2, + Iterations: 1, + }, p) +} diff --git a/cmd/nebula-cert/ca.go b/cmd/nebula-cert/ca.go index ce8d5fa..b4f25c9 100644 --- a/cmd/nebula-cert/ca.go +++ b/cmd/nebula-cert/ca.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "io/ioutil" + "math" "net" "os" "strings" @@ -17,15 +18,19 @@ import ( ) type caFlags struct { - set *flag.FlagSet - name *string - duration *time.Duration - outKeyPath *string - outCertPath *string - outQRPath *string - groups *string - ips *string - subnets *string + set *flag.FlagSet + name *string + duration *time.Duration + outKeyPath *string + outCertPath *string + outQRPath *string + groups *string + ips *string + subnets *string + argonMemory *uint + argonIterations *uint + argonParallelism *uint + encryption *bool } func newCaFlags() *caFlags { @@ -39,10 +44,28 @@ func newCaFlags() *caFlags { cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use") cf.ips = cf.set.String("ips", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses") cf.subnets = cf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets") + cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase") + cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase") + cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase") + cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format") return &cf } -func ca(args []string, out io.Writer, errOut io.Writer) error { +func parseArgonParameters(memory uint, parallelism uint, iterations uint) (*cert.Argon2Parameters, error) { + if memory <= 0 || memory > math.MaxUint32 { + return nil, newHelpErrorf("-argon-memory must be be greater than 0 and no more than %d KiB", uint32(math.MaxUint32)) + } + if parallelism <= 0 || parallelism > math.MaxUint8 { + return nil, newHelpErrorf("-argon-parallelism must be be greater than 0 and no more than %d", math.MaxUint8) + } + if iterations <= 0 || iterations > math.MaxUint32 { + return nil, newHelpErrorf("-argon-iterations must be be greater than 0 and no more than %d", uint32(math.MaxUint32)) + } + + return cert.NewArgon2Parameters(uint32(memory), uint8(parallelism), uint32(iterations)), nil +} + +func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error { cf := newCaFlags() err := cf.set.Parse(args) if err != nil { @@ -58,6 +81,12 @@ func ca(args []string, out io.Writer, errOut io.Writer) error { if err := mustFlagString("out-crt", cf.outCertPath); err != nil { return err } + var kdfParams *cert.Argon2Parameters + if *cf.encryption { + if kdfParams, err = parseArgonParameters(*cf.argonMemory, *cf.argonParallelism, *cf.argonIterations); err != nil { + return err + } + } if *cf.duration <= 0 { return &helpError{"-duration must be greater than 0"} @@ -109,6 +138,28 @@ func ca(args []string, out io.Writer, errOut io.Writer) error { } } + var passphrase []byte + if *cf.encryption { + for i := 0; i < 5; i++ { + out.Write([]byte("Enter passphrase: ")) + passphrase, err = pr.ReadPassword() + + if err == ErrNoTerminal { + return fmt.Errorf("out-key must be encrypted interactively") + } else if err != nil { + return fmt.Errorf("error reading passphrase: %s", err) + } + + if len(passphrase) > 0 { + break + } + } + + if len(passphrase) == 0 { + return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext") + } + } + pub, rawPriv, err := ed25519.GenerateKey(rand.Reader) if err != nil { return fmt.Errorf("error while generating ed25519 keys: %s", err) @@ -140,7 +191,17 @@ func ca(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("error while signing: %s", err) } - err = ioutil.WriteFile(*cf.outKeyPath, cert.MarshalEd25519PrivateKey(rawPriv), 0600) + if *cf.encryption { + b, err := cert.EncryptAndMarshalEd25519PrivateKey(rawPriv, passphrase, kdfParams) + if err != nil { + return fmt.Errorf("error while encrypting out-key: %s", err) + } + + err = ioutil.WriteFile(*cf.outKeyPath, b, 0600) + } else { + err = ioutil.WriteFile(*cf.outKeyPath, cert.MarshalEd25519PrivateKey(rawPriv), 0600) + } + if err != nil { return fmt.Errorf("error while writing out-key: %s", err) } diff --git a/cmd/nebula-cert/ca_test.go b/cmd/nebula-cert/ca_test.go index 372a4f1..0ce9182 100644 --- a/cmd/nebula-cert/ca_test.go +++ b/cmd/nebula-cert/ca_test.go @@ -5,8 +5,11 @@ package main import ( "bytes" + "encoding/pem" + "errors" "io/ioutil" "os" + "strings" "testing" "time" @@ -26,8 +29,16 @@ func Test_caHelp(t *testing.T) { assert.Equal( t, "Usage of "+os.Args[0]+" ca : create a self signed certificate authority\n"+ + " -argon-iterations uint\n"+ + " \tOptional: Argon2 iterations parameter used for encrypted private key passphrase (default 1)\n"+ + " -argon-memory uint\n"+ + " \tOptional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase (default 2097152)\n"+ + " -argon-parallelism uint\n"+ + " \tOptional: Argon2 parallelism parameter used for encrypted private key passphrase (default 4)\n"+ " -duration duration\n"+ " \tOptional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\" (default 8760h0m0s)\n"+ + " -encrypt\n"+ + " \tOptional: prompt for passphrase and write out-key in an encrypted format\n"+ " -groups string\n"+ " \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+ " -ips string\n"+ @@ -50,18 +61,38 @@ func Test_ca(t *testing.T) { ob := &bytes.Buffer{} eb := &bytes.Buffer{} + nopw := &StubPasswordReader{ + password: []byte(""), + err: nil, + } + + errpw := &StubPasswordReader{ + password: []byte(""), + err: errors.New("stub error"), + } + + passphrase := []byte("DO NOT USE THIS KEY") + testpw := &StubPasswordReader{ + password: passphrase, + err: nil, + } + + pwPromptOb := "Enter passphrase: " + // required args - assertHelpError(t, ca([]string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb), "-name is required") + assertHelpError(t, ca( + []string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw, + ), "-name is required") assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // ipv4 only ips - assertHelpError(t, ca([]string{"-name", "ipv6", "-ips", "100::100/100"}, ob, eb), "invalid ip definition: can only be ipv4, have 100::100/100") + assertHelpError(t, ca([]string{"-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100") assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // ipv4 only subnets - assertHelpError(t, ca([]string{"-name", "ipv6", "-subnets", "100::100/100"}, ob, eb), "invalid subnet definition: can only be ipv4, have 100::100/100") + assertHelpError(t, ca([]string{"-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100") assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -69,7 +100,7 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args := []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"} - assert.EqualError(t, ca(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) + assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -82,7 +113,7 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()} - assert.EqualError(t, ca(args, ob, eb), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) + assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -96,7 +127,7 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.Nil(t, ca(args, ob, eb)) + assert.Nil(t, ca(args, ob, eb, nopw)) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -122,19 +153,65 @@ func Test_ca(t *testing.T) { assert.Equal(t, "", lCrt.Details.Issuer) assert.True(t, lCrt.CheckSignature(lCrt.Details.PublicKey)) + // test encrypted key + os.Remove(keyF.Name()) + os.Remove(crtF.Name()) + ob.Reset() + eb.Reset() + args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + assert.Nil(t, ca(args, ob, eb, testpw)) + assert.Equal(t, pwPromptOb, ob.String()) + assert.Equal(t, "", eb.String()) + + // read encrypted key file and verify default params + rb, _ = ioutil.ReadFile(keyF.Name()) + k, _ := pem.Decode(rb) + ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes) + assert.Nil(t, err) + // we won't know salt in advance, so just check start of string + assert.Equal(t, uint32(2*1024*1024), ned.EncryptionMetadata.Argon2Parameters.Memory) + assert.Equal(t, uint8(4), ned.EncryptionMetadata.Argon2Parameters.Parallelism) + assert.Equal(t, uint32(1), ned.EncryptionMetadata.Argon2Parameters.Iterations) + + // verify the key is valid and decrypt-able + lKey, b, err = cert.DecryptAndUnmarshalEd25519PrivateKey(passphrase, rb) + assert.Nil(t, err) + assert.Len(t, b, 0) + assert.Len(t, lKey, 64) + + // test when reading passsword results in an error + os.Remove(keyF.Name()) + os.Remove(crtF.Name()) + ob.Reset() + eb.Reset() + args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + assert.Error(t, ca(args, ob, eb, errpw)) + assert.Equal(t, pwPromptOb, ob.String()) + assert.Equal(t, "", eb.String()) + + // test when user fails to enter a password + os.Remove(keyF.Name()) + os.Remove(crtF.Name()) + ob.Reset() + eb.Reset() + args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext") + assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up + assert.Equal(t, "", eb.String()) + // create valid cert/key for overwrite tests os.Remove(keyF.Name()) os.Remove(crtF.Name()) ob.Reset() eb.Reset() args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.Nil(t, ca(args, ob, eb)) + assert.Nil(t, ca(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file ob.Reset() eb.Reset() args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.EqualError(t, ca(args, ob, eb), "refusing to overwrite existing CA key: "+keyF.Name()) + assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name()) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -143,7 +220,7 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.EqualError(t, ca(args, ob, eb), "refusing to overwrite existing CA cert: "+crtF.Name()) + assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name()) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) os.Remove(keyF.Name()) diff --git a/cmd/nebula-cert/main.go b/cmd/nebula-cert/main.go index 3fba40a..b803d30 100644 --- a/cmd/nebula-cert/main.go +++ b/cmd/nebula-cert/main.go @@ -62,11 +62,11 @@ func main() { switch args[0] { case "ca": - err = ca(args[1:], os.Stdout, os.Stderr) + err = ca(args[1:], os.Stdout, os.Stderr, StdinPasswordReader{}) case "keygen": err = keygen(args[1:], os.Stdout, os.Stderr) case "sign": - err = signCert(args[1:], os.Stdout, os.Stderr) + err = signCert(args[1:], os.Stdout, os.Stderr, StdinPasswordReader{}) case "print": err = printCert(args[1:], os.Stdout, os.Stderr) case "verify": diff --git a/cmd/nebula-cert/passwords.go b/cmd/nebula-cert/passwords.go new file mode 100644 index 0000000..8129560 --- /dev/null +++ b/cmd/nebula-cert/passwords.go @@ -0,0 +1,28 @@ +package main + +import ( + "errors" + "fmt" + "os" + + "golang.org/x/term" +) + +var ErrNoTerminal = errors.New("cannot read password from nonexistent terminal") + +type PasswordReader interface { + ReadPassword() ([]byte, error) +} + +type StdinPasswordReader struct{} + +func (pr StdinPasswordReader) ReadPassword() ([]byte, error) { + if !term.IsTerminal(int(os.Stdin.Fd())) { + return nil, ErrNoTerminal + } + + password, err := term.ReadPassword(int(os.Stdin.Fd())) + fmt.Println() + + return password, err +} diff --git a/cmd/nebula-cert/passwords_test.go b/cmd/nebula-cert/passwords_test.go new file mode 100644 index 0000000..d0b64b9 --- /dev/null +++ b/cmd/nebula-cert/passwords_test.go @@ -0,0 +1,10 @@ +package main + +type StubPasswordReader struct { + password []byte + err error +} + +func (pr *StubPasswordReader) ReadPassword() ([]byte, error) { + return pr.password, pr.err +} diff --git a/cmd/nebula-cert/sign.go b/cmd/nebula-cert/sign.go index 4b3b899..68104ad 100644 --- a/cmd/nebula-cert/sign.go +++ b/cmd/nebula-cert/sign.go @@ -1,6 +1,7 @@ package main import ( + "crypto/ed25519" "crypto/rand" "flag" "fmt" @@ -49,7 +50,7 @@ func newSignFlags() *signFlags { } -func signCert(args []string, out io.Writer, errOut io.Writer) error { +func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error { sf := newSignFlags() err := sf.set.Parse(args) if err != nil { @@ -77,8 +78,36 @@ func signCert(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("error while reading ca-key: %s", err) } - caKey, _, err := cert.UnmarshalEd25519PrivateKey(rawCAKey) - if err != nil { + var caKey ed25519.PrivateKey + + // naively attempt to decode the private key as though it is not encrypted + caKey, _, err = cert.UnmarshalEd25519PrivateKey(rawCAKey) + if err == cert.ErrPrivateKeyEncrypted { + // ask for a passphrase until we get one + var passphrase []byte + for i := 0; i < 5; i++ { + out.Write([]byte("Enter passphrase: ")) + passphrase, err = pr.ReadPassword() + + if err == ErrNoTerminal { + return fmt.Errorf("ca-key is encrypted and must be decrypted interactively") + } else if err != nil { + return fmt.Errorf("error reading password: %s", err) + } + + if len(passphrase) > 0 { + break + } + } + if len(passphrase) == 0 { + return fmt.Errorf("cannot open encrypted ca-key without passphrase") + } + + caKey, _, err = cert.DecryptAndUnmarshalEd25519PrivateKey(passphrase, rawCAKey) + if err != nil { + return fmt.Errorf("error while parsing encrypted ca-key: %s", err) + } + } else if err != nil { return fmt.Errorf("error while parsing ca-key: %s", err) } diff --git a/cmd/nebula-cert/sign_test.go b/cmd/nebula-cert/sign_test.go index 4976fa3..afde357 100644 --- a/cmd/nebula-cert/sign_test.go +++ b/cmd/nebula-cert/sign_test.go @@ -6,6 +6,7 @@ package main import ( "bytes" "crypto/rand" + "errors" "io/ioutil" "os" "testing" @@ -58,17 +59,39 @@ func Test_signCert(t *testing.T) { ob := &bytes.Buffer{} eb := &bytes.Buffer{} + nopw := &StubPasswordReader{ + password: []byte(""), + err: nil, + } + + errpw := &StubPasswordReader{ + password: []byte(""), + err: errors.New("stub error"), + } + + passphrase := []byte("DO NOT USE THIS KEY") + testpw := &StubPasswordReader{ + password: passphrase, + err: nil, + } + // required args - assertHelpError(t, signCert([]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb), "-name is required") + assertHelpError(t, signCert( + []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, + ), "-name is required") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) - assertHelpError(t, signCert([]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb), "-ip is required") + assertHelpError(t, signCert( + []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, + ), "-ip is required") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // cannot set -in-pub and -out-key - assertHelpError(t, signCert([]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb), "cannot set both -in-pub and -out-key") + assertHelpError(t, signCert( + []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw, + ), "cannot set both -in-pub and -out-key") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -76,7 +99,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args := []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb), "error while reading ca-key: open ./nope: "+NoSuchFileError) + assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError) // failed to unmarshal key ob.Reset() @@ -86,7 +109,7 @@ func Test_signCert(t *testing.T) { defer os.Remove(caKeyF.Name()) args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb), "error while parsing ca-key: input did not contain a valid PEM encoded block") + assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -98,7 +121,7 @@ func Test_signCert(t *testing.T) { // failed to read cert args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb), "error while reading ca-crt: open ./nope: "+NoSuchFileError) + assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -110,7 +133,7 @@ func Test_signCert(t *testing.T) { defer os.Remove(caCrtF.Name()) args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb), "error while parsing ca-crt: input did not contain a valid PEM encoded block") + assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -129,7 +152,7 @@ func Test_signCert(t *testing.T) { // failed to read pub args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb), "error while reading in-pub: open ./nope: "+NoSuchFileError) + assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -141,7 +164,7 @@ func Test_signCert(t *testing.T) { defer os.Remove(inPubF.Name()) args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb), "error while parsing in-pub: input did not contain a valid PEM encoded block") + assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -155,14 +178,14 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assertHelpError(t, signCert(args, ob, eb), "invalid ip definition: invalid CIDR address: a1.1.1.1/24") + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: invalid CIDR address: a1.1.1.1/24") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assertHelpError(t, signCert(args, ob, eb), "invalid ip definition: can only be ipv4, have 100::100/100") + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -170,14 +193,14 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} - assertHelpError(t, signCert(args, ob, eb), "invalid subnet definition: invalid CIDR address: a") + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: invalid CIDR address: a") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"} - assertHelpError(t, signCert(args, ob, eb), "invalid subnet definition: can only be ipv4, have 100::100/100") + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -191,7 +214,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} - assert.EqualError(t, signCert(args, ob, eb), "refusing to sign, root certificate does not match private key") + assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -199,7 +222,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"} - assert.EqualError(t, signCert(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) + assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -212,7 +235,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"} - assert.EqualError(t, signCert(args, ob, eb), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) + assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) os.Remove(keyF.Name()) @@ -226,7 +249,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Nil(t, signCert(args, ob, eb)) + assert.Nil(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -268,7 +291,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"} - assert.Nil(t, signCert(args, ob, eb)) + assert.Nil(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -283,7 +306,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.EqualError(t, signCert(args, ob, eb), "refusing to sign, root certificate constraints violated: certificate expires after signing certificate") + assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate constraints violated: certificate expires after signing certificate") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -291,14 +314,14 @@ func Test_signCert(t *testing.T) { os.Remove(keyF.Name()) os.Remove(crtF.Name()) args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Nil(t, signCert(args, ob, eb)) + assert.Nil(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing key file os.Remove(crtF.Name()) ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.EqualError(t, signCert(args, ob, eb), "refusing to overwrite existing key: "+keyF.Name()) + assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -306,14 +329,83 @@ func Test_signCert(t *testing.T) { os.Remove(keyF.Name()) os.Remove(crtF.Name()) args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Nil(t, signCert(args, ob, eb)) + assert.Nil(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file os.Remove(keyF.Name()) ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.EqualError(t, signCert(args, ob, eb), "refusing to overwrite existing cert: "+crtF.Name()) + assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) + + // create valid cert/key using encrypted CA key + os.Remove(caKeyF.Name()) + os.Remove(caCrtF.Name()) + os.Remove(keyF.Name()) + os.Remove(crtF.Name()) + ob.Reset() + eb.Reset() + + caKeyF, err = ioutil.TempFile("", "sign-cert.key") + assert.Nil(t, err) + defer os.Remove(caKeyF.Name()) + + caCrtF, err = ioutil.TempFile("", "sign-cert.crt") + assert.Nil(t, err) + defer os.Remove(caCrtF.Name()) + + // generate the encrypted key + caPub, caPriv, _ = ed25519.GenerateKey(rand.Reader) + kdfParams := cert.NewArgon2Parameters(64*1024, 4, 3) + b, _ = cert.EncryptAndMarshalEd25519PrivateKey(caPriv, passphrase, kdfParams) + caKeyF.Write(b) + + ca = cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "ca", + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Minute * 200), + PublicKey: caPub, + IsCA: true, + }, + } + b, _ = ca.MarshalToPEM() + caCrtF.Write(b) + + // test with the proper password + args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + assert.Nil(t, signCert(args, ob, eb, testpw)) + assert.Equal(t, "Enter passphrase: ", ob.String()) + assert.Empty(t, eb.String()) + + // test with the wrong password + ob.Reset() + eb.Reset() + + testpw.password = []byte("invalid password") + args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + assert.Error(t, signCert(args, ob, eb, testpw)) + assert.Equal(t, "Enter passphrase: ", ob.String()) + assert.Empty(t, eb.String()) + + // test with the user not entering a password + ob.Reset() + eb.Reset() + + args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + assert.Error(t, signCert(args, ob, eb, nopw)) + // normally the user hitting enter on the prompt would add newlines between these + assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String()) + assert.Empty(t, eb.String()) + + // test an error condition + ob.Reset() + eb.Reset() + + args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + assert.Error(t, signCert(args, ob, eb, errpw)) + assert.Equal(t, "Enter passphrase: ", ob.String()) + assert.Empty(t, eb.String()) } diff --git a/connection_manager.go b/connection_manager.go index 0ea1f75..14086ac 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -183,12 +183,6 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, return } - if n.intf.lightHouse.IsLighthouseIP(hostinfo.vpnIp) { - // We are sending traffic to the lighthouse, let recv_error sort out any issues instead of testing the tunnel - n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval) - return - } - if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok { // We have already sent a test packet and nothing was returned, this hostinfo is dead hostinfo.logger(n.l). @@ -205,10 +199,26 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, Debug("Tunnel status") if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo { - if n.punchy.GetTargetEverything() { - // Maybe the remote is sending us packets but our NAT is blocking it and since we are configured to punch to all - // known remotes, go ahead and do that AND send a test packet + if !outTraffic { + // If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel. + // Just maintain NAT state if configured to do so. n.sendPunch(hostinfo) + n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval) + return + + } + + if n.punchy.GetTargetEverything() { + // This is similar to the old punchy behavior with a slight optimization. + // We aren't receiving traffic but we are sending it, punch on all known + // ips in case we need to re-prime NAT state + n.sendPunch(hostinfo) + } + + if n.intf.lightHouse.IsLighthouseIP(hostinfo.vpnIp) { + // We are sending traffic to the lighthouse, let recv_error sort out any issues instead of testing the tunnel + n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval) + return } // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues diff --git a/connection_manager_test.go b/connection_manager_test.go index e05376d..3d79cb0 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -98,6 +98,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { assert.NotContains(t, nc.in, hostinfo.localIndexId) // Do another traffic check tick, this host should be pending deletion now + nc.Out(hostinfo.localIndexId) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId) assert.NotContains(t, nc.out, hostinfo.localIndexId) @@ -175,6 +176,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { assert.NotContains(t, nc.in, hostinfo.localIndexId) // Do another traffic check tick, this host should be pending deletion now + nc.Out(hostinfo.localIndexId) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId) assert.NotContains(t, nc.out, hostinfo.localIndexId) diff --git a/connection_state.go b/connection_state.go index 2a7be15..4f8a577 100644 --- a/connection_state.go +++ b/connection_state.go @@ -9,6 +9,7 @@ import ( "github.com/flynn/noise" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/noiseutil" ) const ReplayWindow = 1024 @@ -28,7 +29,7 @@ type ConnectionState struct { } func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { - cs := noise.NewCipherSuite(noise.DH25519, noise.CipherAESGCM, noise.HashSHA256) + cs := noise.NewCipherSuite(noise.DH25519, noiseutil.CipherAESGCM, noise.HashSHA256) if f.cipher == "chachapoly" { cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256) } diff --git a/examples/config.yml b/examples/config.yml index eae2db5..dac85ee 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -107,7 +107,7 @@ lighthouse: # Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined, # however using port 0 will dynamically assign a port and is recommended for roaming nodes. listen: - # To listen on both any ipv4 and ipv6 use "[::]" + # To listen on both any ipv4 and ipv6 use "::" host: 0.0.0.0 port: 4242 # Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg) @@ -204,20 +204,24 @@ tun: tx_queue: 500 # Default MTU for every packet, safe setting is (and the default) 1300 for internet based traffic mtu: 1300 + # Route based MTU overrides, you have known vpn ip paths that can support larger MTUs you can increase/decrease them here routes: #- mtu: 8800 # route: 10.0.0.0/16 + # Unsafe routes allows you to route traffic over nebula to non-nebula nodes # Unsafe routes should be avoided unless you have hosts/services that cannot run nebula # NOTE: The nebula certificate of the "via" node *MUST* have the "route" defined as a subnet in its certificate - # `mtu` will default to tun mtu if this option is not specified - # `metric` will default to 0 if this option is not specified + # `mtu`: will default to tun mtu if this option is not specified + # `metric`: will default to 0 if this option is not specified + # `install`: will default to true, controls whether this route is installed in the systems routing table. unsafe_routes: #- route: 172.16.1.0/24 # via: 192.168.100.99 # mtu: 1300 # metric: 100 + # install: true # EXPERIMENTAL: This option may change or disappear in the future. # Multiport spreads outgoing UDP packets across multiple UDP send ports, diff --git a/go.mod b/go.mod index 2b2fafa..ea42666 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( golang.org/x/exp v0.0.0-20230310171629-522b1b587ee0 golang.org/x/net v0.8.0 golang.org/x/sys v0.6.0 + golang.org/x/term v0.6.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard/windows v0.5.3 google.golang.org/protobuf v1.29.0 @@ -42,7 +43,6 @@ require ( github.com/prometheus/procfs v0.9.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect golang.org/x/mod v0.9.0 // indirect - golang.org/x/term v0.6.0 // indirect golang.org/x/tools v0.7.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 6d3febc..5571236 100644 --- a/go.sum +++ b/go.sum @@ -153,8 +153,6 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= -golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2 h1:Jvc7gsqn21cJHCmAWx0LiimpP18LZmUxkT5Mp7EZ1mI= -golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/exp v0.0.0-20230310171629-522b1b587ee0 h1:LGJsf5LRplCck6jUCH3dBL2dmycNruWNF5xugkSlfXw= golang.org/x/exp v0.0.0-20230310171629-522b1b587ee0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= diff --git a/handshake.go b/handshake.go index 1cad0db..1f2f03a 100644 --- a/handshake.go +++ b/handshake.go @@ -5,7 +5,7 @@ import ( "github.com/slackhq/nebula/udp" ) -func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via interface{}, packet []byte, h *header.H, hostinfo *HostInfo) { +func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H, hostinfo *HostInfo) { // First remote allow list check before we know the vpnIp if addr != nil { if !f.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) { diff --git a/handshake_ix.go b/handshake_ix.go index 7d951ca..39615b1 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -77,7 +77,7 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) { hostinfo.handshakeStart = time.Now() } -func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []byte, h *header.H) { +func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) { ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0) // Mark packet 1 as seen so it doesn't show up as missed ci.window.Update(f.l, 1) @@ -282,14 +282,13 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b } return } else { - via2 := via.(*ViaSender) - if via2 == nil { + if via == nil { f.l.Error("Handshake send failed: both addr and via are nil.") return } - hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp) - f.SendVia(via2.relayHI, via2.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via2.relayHI.vpnIp). + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) + f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) + f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). Info("Handshake message sent") return @@ -364,14 +363,13 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b Info("Handshake message sent") } } else { - via2 := via.(*ViaSender) - if via2 == nil { + if via == nil { f.l.Error("Handshake send failed: both addr and via are nil.") return } - hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp) - f.SendVia(via2.relayHI, via2.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnIp", vpnIp).WithField("relay", via2.relayHI.vpnIp). + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) + f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) + f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -387,7 +385,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b return } -func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo *HostInfo, packet []byte, h *header.H) bool { +func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *HostInfo, packet []byte, h *header.H) bool { if hostinfo == nil { // Nothing here to tear down, got a bogus stage 2 packet return true @@ -551,8 +549,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo * if addr != nil { hostinfo.SetRemote(addr) } else { - via2 := via.(*ViaSender) - hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp) + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) } // Build up the radix for the firewall if we have subnets in the cert diff --git a/handshake_manager.go b/handshake_manager.go index 7a411c5..9a20456 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -56,10 +56,6 @@ type HandshakeManager struct { multiPort MultiPortConfig udpRaw *udp.RawConn - // vpnIps is another map similar to the pending hostmap but tracks entries in the wheel instead - // this is to avoid situations where the same vpn ip enters the wheel and causes rapid fire handshaking - vpnIps map[iputil.VpnIp]struct{} - // can be used to trigger outbound handshake for the given vpnIp trigger chan iputil.VpnIp } @@ -73,7 +69,6 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [ config: config, trigger: make(chan iputil.VpnIp, config.triggerBuffer), OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp](config.tryInterval, hsTimeout(config.retries, config.tryInterval)), - vpnIps: map[iputil.VpnIp]struct{}{}, messageMetrics: config.messageMetrics, metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil), metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil), @@ -81,7 +76,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [ } } -func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) { +func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) { clockSource := time.NewTicker(c.config.tryInterval) defer clockSource.Stop() @@ -97,7 +92,7 @@ func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) { } } -func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.EncWriter) { +func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) { c.OutboundHandshakeTimer.Advance(now) for { vpnIp, has := c.OutboundHandshakeTimer.Purge() @@ -108,10 +103,9 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.E } } -func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, lighthouseTriggered bool) { +func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, lighthouseTriggered bool) { hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp) if err != nil { - delete(c.vpnIps, vpnIp) return } hostinfo.Lock() @@ -324,10 +318,7 @@ func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *H hostinfo, created := c.pendingHostMap.AddVpnIp(vpnIp, init) if created { - if _, ok := c.vpnIps[vpnIp]; !ok { - c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval) - } - c.vpnIps[vpnIp] = struct{}{} + c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval) c.metricInitiated.Inc(1) } diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 84b8ef6..3be8a1b 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -84,7 +84,7 @@ func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess return } -func (mw *mockEncWriter) SendVia(via interface{}, relay interface{}, ad, nb, out []byte, nocopy bool) { +func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { return } diff --git a/inside.go b/inside.go index 3804f75..198dfdd 100644 --- a/inside.go +++ b/inside.go @@ -6,6 +6,7 @@ import ( "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/udp" ) @@ -247,15 +248,17 @@ func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *C // nb is a buffer used to store the nonce value, re-used for performance reasons. // out is a buffer used to store the result of the Encrypt operation // q indicates which writer to use to send the packet. -func (f *Interface) SendVia(viaIfc interface{}, - relayIfc interface{}, +func (f *Interface) SendVia(via *HostInfo, + relay *Relay, ad, nb, out []byte, nocopy bool, ) { - via := viaIfc.(*HostInfo) - relay := relayIfc.(*Relay) + if noiseutil.EncryptLockNeeded { + // NOTE: for goboring AESGCMTLS we need to lock because of the nonce check + via.ConnectionState.writeLock.Lock() + } c := via.ConnectionState.messageCounter.Add(1) out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c) @@ -264,6 +267,9 @@ func (f *Interface) SendVia(viaIfc interface{}, // Authenticate the header and payload, but do not encrypt for this message type. // The payload consists of the inner, unencrypted Nebula header, as well as the end-to-end encrypted payload. if len(out)+len(ad)+via.ConnectionState.eKey.Overhead() > cap(out) { + if noiseutil.EncryptLockNeeded { + via.ConnectionState.writeLock.Unlock() + } via.logger(f.l). WithField("outCap", cap(out)). WithField("payloadLen", len(ad)). @@ -285,6 +291,9 @@ func (f *Interface) SendVia(viaIfc interface{}, var err error out, err = via.ConnectionState.eKey.EncryptDanger(out, out, nil, c, nb) + if noiseutil.EncryptLockNeeded { + via.ConnectionState.writeLock.Unlock() + } if err != nil { via.logger(f.l).WithError(err).Info("Failed to EncryptDanger in sendVia") return @@ -330,8 +339,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType out = out[header.Len:] } - //TODO: enable if we do more than 1 tun queue - //ci.writeLock.Lock() + if noiseutil.EncryptLockNeeded { + // NOTE: for goboring AESGCMTLS we need to lock because of the nonce check + ci.writeLock.Lock() + } c := ci.messageCounter.Add(1) //l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p) @@ -352,8 +363,9 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType var err error out, err = ci.eKey.EncryptDanger(out, out, p, c, nb) - //TODO: see above note on lock - //ci.writeLock.Unlock() + if noiseutil.EncryptLockNeeded { + ci.writeLock.Unlock() + } if err != nil { hostinfo.logger(f.l).WithError(err). WithField("udpAddr", remote).WithField("counter", c). diff --git a/interface.go b/interface.go index b7ade68..966e3a5 100644 --- a/interface.go +++ b/interface.go @@ -16,6 +16,7 @@ import ( "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" + "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" @@ -101,6 +102,18 @@ type MultiPortConfig struct { TxHandshakeDelay int } +type EncWriter interface { + SendVia(via *HostInfo, + relay *Relay, + ad, + nb, + out []byte, + nocopy bool, + ) + SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) + Handshake(vpnIp iputil.VpnIp) +} + type sendRecvErrorConfig uint8 const ( @@ -252,7 +265,7 @@ func (f *Interface) listenOut(i int) { lhh := f.lightHouse.NewRequestHandler() conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) - li.ListenOut(f.readOutsidePackets, lhh.HandleRequest, conntrackCache, i) + li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i) } func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { @@ -396,6 +409,8 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { var rawStats func() + certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil) + for { select { case <-ctx.Done(): @@ -410,6 +425,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { } rawStats() } + certExpirationGauge.Update(int64(f.certState.Load().certificate.Details.NotAfter.Sub(time.Now()) / time.Second)) } } } diff --git a/lighthouse.go b/lighthouse.go index 5b34a3e..d6b6a5f 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -65,7 +65,7 @@ type LightHouse struct { interval atomic.Int64 updateCancel context.CancelFunc updateParentCtx context.Context - updateUdp udp.EncWriter + updateUdp EncWriter nebulaPort uint32 // 32 bits because protobuf does not have a uint16 advertiseAddrs atomic.Pointer[[]netIpAndPort] @@ -382,7 +382,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList return nil } -func (lh *LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList { +func (lh *LightHouse) Query(ip iputil.VpnIp, f EncWriter) *RemoteList { if !lh.IsLighthouseIP(ip) { lh.QueryServer(ip, f) } @@ -396,7 +396,7 @@ func (lh *LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList { } // This is asynchronous so no reply should be expected -func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f udp.EncWriter) { +func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f EncWriter) { if lh.amLighthouse { return } @@ -629,7 +629,7 @@ func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr { return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port)) } -func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) { +func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) { lh.updateParentCtx = ctx lh.updateUdp = f @@ -655,7 +655,7 @@ func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) { } } -func (lh *LightHouse) SendUpdate(f udp.EncWriter) { +func (lh *LightHouse) SendUpdate(f EncWriter) { var v4 []*Ip4AndPort var v6 []*Ip6AndPort @@ -760,7 +760,13 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta { return lhh.meta } -func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w udp.EncWriter) { +func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc { + return func(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte) { + lhh.HandleRequest(rAddr, vpnIp, p, f) + } +} + +func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) { n := lhh.resetMeta() err := n.Unmarshal(p) if err != nil { @@ -795,7 +801,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, } } -func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w udp.EncWriter) { +func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w EncWriter) { // Exit if we don't answer queries if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { @@ -928,7 +934,7 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp am.Unlock() } -func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w udp.EncWriter) { +func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) { if !lhh.lh.IsLighthouseIP(vpnIp) { return } diff --git a/lighthouse_test.go b/lighthouse_test.go index e5a1692..1824463 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -372,7 +372,7 @@ type testEncWriter struct { metaFilter *NebulaMeta_MessageType } -func (tw *testEncWriter) SendVia(via interface{}, relay interface{}, ad, nb, out []byte, nocopy bool) { +func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { } func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) { } diff --git a/main.go b/main.go index f8a1fe0..2ee5099 100644 --- a/main.go +++ b/main.go @@ -151,8 +151,21 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg port := c.GetInt("listen.port", 0) if !configTest { + rawListenHost := c.GetString("listen.host", "0.0.0.0") + var listenHost *net.IPAddr + if rawListenHost == "[::]" { + // Old guidance was to provide the literal `[::]` in `listen.host` but that won't resolve. + listenHost = &net.IPAddr{IP: net.IPv6zero} + + } else { + listenHost, err = net.ResolveIPAddr("ip", rawListenHost) + if err != nil { + return nil, util.NewContextualError("Failed to resolve listen.host", nil, err) + } + } + for i := 0; i < routines; i++ { - udpServer, err := udp.NewListener(l, c.GetString("listen.host", "0.0.0.0"), port, routines > 1, c.GetInt("listen.batch", 64)) + udpServer, err := udp.NewListener(l, listenHost.IP, port, routines > 1, c.GetInt("listen.batch", 64)) if err != nil { return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) } diff --git a/noiseutil/boring.go b/noiseutil/boring.go new file mode 100644 index 0000000..e9ad19b --- /dev/null +++ b/noiseutil/boring.go @@ -0,0 +1,80 @@ +//go:build boringcrypto +// +build boringcrypto + +package noiseutil + +import ( + "crypto/aes" + "crypto/cipher" + "encoding/binary" + + // unsafe needed for go:linkname + _ "unsafe" + + "github.com/flynn/noise" +) + +// EncryptLockNeeded indicates if calls to Encrypt need a lock +// This is true for boringcrypto because the Seal function verifies that the +// nonce is strictly increasing. +const EncryptLockNeeded = true + +// NewGCMTLS is no longer exposed in go1.19+, so we need to link it in +// See: https://github.com/golang/go/issues/56326 +// +// NewGCMTLS is the internal method used with boringcrypto that provices a +// validated mode of AES-GCM which enforces the nonce is strictly +// monotonically increasing. This is the TLS 1.2 specification for nonce +// generation (which also matches the method used by the Noise Protocol) +// +// - https://github.com/golang/go/blob/go1.19/src/crypto/tls/cipher_suites.go#L520-L522 +// - https://github.com/golang/go/blob/go1.19/src/crypto/internal/boring/aes.go#L235-L237 +// - https://github.com/golang/go/blob/go1.19/src/crypto/internal/boring/aes.go#L250 +// - https://github.com/google/boringssl/blob/ae223d6138807a13006342edfeef32e813246b39/include/openssl/aead.h#L379-L381 +// - https://github.com/google/boringssl/blob/ae223d6138807a13006342edfeef32e813246b39/crypto/fipsmodule/cipher/e_aes.c#L1082-L1093 +// +//go:linkname newGCMTLS crypto/internal/boring.NewGCMTLS +func newGCMTLS(c cipher.Block) (cipher.AEAD, error) + +type cipherFn struct { + fn func([32]byte) noise.Cipher + name string +} + +func (c cipherFn) Cipher(k [32]byte) noise.Cipher { return c.fn(k) } +func (c cipherFn) CipherName() string { return c.name } + +// CipherAESGCM is the AES256-GCM AEAD cipher (using NewGCMTLS when GoBoring is present) +var CipherAESGCM noise.CipherFunc = cipherFn{cipherAESGCMBoring, "AESGCM"} + +func cipherAESGCMBoring(k [32]byte) noise.Cipher { + c, err := aes.NewCipher(k[:]) + if err != nil { + panic(err) + } + gcm, err := newGCMTLS(c) + if err != nil { + panic(err) + } + return aeadCipher{ + gcm, + func(n uint64) []byte { + var nonce [12]byte + binary.BigEndian.PutUint64(nonce[4:], n) + return nonce[:] + }, + } +} + +type aeadCipher struct { + cipher.AEAD + nonce func(uint64) []byte +} + +func (c aeadCipher) Encrypt(out []byte, n uint64, ad, plaintext []byte) []byte { + return c.Seal(out, c.nonce(n), plaintext, ad) +} + +func (c aeadCipher) Decrypt(out []byte, n uint64, ad, ciphertext []byte) ([]byte, error) { + return c.Open(out, c.nonce(n), ciphertext, ad) +} diff --git a/noiseutil/boring_test.go b/noiseutil/boring_test.go new file mode 100644 index 0000000..bc5ff50 --- /dev/null +++ b/noiseutil/boring_test.go @@ -0,0 +1,39 @@ +//go:build boringcrypto +// +build boringcrypto + +package noiseutil + +import ( + "encoding/hex" + "testing" + + "github.com/stretchr/testify/assert" +) + +// Ensure NewGCMTLS validates the nonce is non-repeating +func TestNewGCMTLS(t *testing.T) { + // Test Case 16 from GCM Spec: + // - (now dead link): http://csrc.nist.gov/groups/ST/toolkit/BCM/documents/proposedmodes/gcm/gcm-spec.pdf + // - as listed in boringssl tests: https://github.com/google/boringssl/blob/fips-20220613/crypto/cipher_extra/test/cipher_tests.txt#L412-L418 + key, _ := hex.DecodeString("feffe9928665731c6d6a8f9467308308feffe9928665731c6d6a8f9467308308") + iv, _ := hex.DecodeString("cafebabefacedbaddecaf888") + plaintext, _ := hex.DecodeString("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39") + aad, _ := hex.DecodeString("feedfacedeadbeeffeedfacedeadbeefabaddad2") + expected, _ := hex.DecodeString("522dc1f099567d07f47f37a32a84427d643a8cdcbfe5c0c97598a2bd2555d1aa8cb08e48590dbb3da7b08b1056828838c5f61e6393ba7a0abcc9f662") + expectedTag, _ := hex.DecodeString("76fc6ece0f4e1768cddf8853bb2d551b") + + expected = append(expected, expectedTag...) + + var keyArray [32]byte + copy(keyArray[:], key) + c := CipherAESGCM.Cipher(keyArray) + aead := c.(aeadCipher).AEAD + + dst := aead.Seal([]byte{}, iv, plaintext, aad) + assert.Equal(t, expected, dst) + + // We expect this to fail since we are re-encrypting with a repeat IV + assert.PanicsWithError(t, "boringcrypto: EVP_AEAD_CTX_seal failed", func() { + dst = aead.Seal([]byte{}, iv, plaintext, aad) + }) +} diff --git a/noiseutil/notboring.go b/noiseutil/notboring.go new file mode 100644 index 0000000..be746f4 --- /dev/null +++ b/noiseutil/notboring.go @@ -0,0 +1,14 @@ +//go:build !boringcrypto +// +build !boringcrypto + +package noiseutil + +import ( + "github.com/flynn/noise" +) + +// EncryptLockNeeded indicates if calls to Encrypt need a lock +const EncryptLockNeeded = false + +// CipherAESGCM is the standard noise.CipherAESGCM when boringcrypto is not enabled +var CipherAESGCM noise.CipherFunc = noise.CipherAESGCM diff --git a/noiseutil/notboring_test.go b/noiseutil/notboring_test.go new file mode 100644 index 0000000..a27dbbd --- /dev/null +++ b/noiseutil/notboring_test.go @@ -0,0 +1,15 @@ +//go:build !boringcrypto +// +build !boringcrypto + +package noiseutil + +import ( + // NOTE: We have to force these imports here or boring_test.go fails to + // compile correctly. This seems to be a Go bug: + // + // $ GOEXPERIMENT=boringcrypto go test ./noiseutil + // # github.com/slackhq/nebula/noiseutil + // boring_test.go:10:2: cannot find package + + _ "github.com/stretchr/testify/assert" +) diff --git a/outside.go b/outside.go index b1e43a8..f2e7048 100644 --- a/outside.go +++ b/outside.go @@ -21,7 +21,23 @@ const ( minFwPacketLen = 4 ) -func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) { +func readOutsidePackets(f *Interface) udp.EncReader { + return func( + addr *udp.Addr, + out []byte, + packet []byte, + header *header.H, + fwPacket *firewall.Packet, + lhh udp.LightHouseHandlerFunc, + nb []byte, + q int, + localCache firewall.ConntrackCache, + ) { + f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache) + } +} + +func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) { err := h.Parse(packet) if err != nil { // TODO: best if we return this and let caller log @@ -149,7 +165,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []by return } - lhf(addr, hostinfo.vpnIp, d, f) + lhf(addr, hostinfo.vpnIp, d) // Fallthrough to the bottom to record incoming traffic diff --git a/overlay/route.go b/overlay/route.go index e8626bb..41c7a9c 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -14,10 +14,11 @@ import ( ) type Route struct { - MTU int - Metric int - Cidr *net.IPNet - Via *iputil.VpnIp + MTU int + Metric int + Cidr *net.IPNet + Via *iputil.VpnIp + Install bool } func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4, error) { @@ -81,7 +82,8 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) { } r := Route{ - MTU: mtu, + Install: true, + MTU: mtu, } _, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute)) @@ -182,10 +184,20 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { viaVpnIp := iputil.Ip2VpnIp(nVia) + install := true + rInstall, ok := m["install"] + if ok { + install, err = strconv.ParseBool(fmt.Sprintf("%v", rInstall)) + if err != nil { + return nil, fmt.Errorf("entry %v.install in tun.unsafe_routes is not a boolean: %v", i+1, err) + } + } + r := Route{ - Via: &viaVpnIp, - MTU: mtu, - Metric: metric, + Via: &viaVpnIp, + MTU: mtu, + Metric: metric, + Install: install, } _, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute)) diff --git a/overlay/route_test.go b/overlay/route_test.go index 1d4286d..f83b5c1 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -92,6 +92,8 @@ func Test_parseRoutes(t *testing.T) { tested := 0 for _, r := range routes { + assert.True(t, r.Install) + if r.MTU == 8000 { assert.Equal(t, "10.0.0.1/32", r.Cidr.String()) tested++ @@ -205,35 +207,45 @@ func Test_parseUnsafeRoutes(t *testing.T) { assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499") + // bad install + c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}} + routes, err = parseUnsafeRoutes(c, n) + assert.Nil(t, routes) + assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax") + // happy case c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{ - map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29"}, - map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32"}, + map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "t"}, + map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32", "install": 0}, + map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1}, map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"}, }} routes, err = parseUnsafeRoutes(c, n) assert.Nil(t, err) - assert.Len(t, routes, 3) + assert.Len(t, routes, 4) tested := 0 for _, r := range routes { if r.MTU == 8000 { assert.Equal(t, "1.0.0.1/32", r.Cidr.String()) + assert.False(t, r.Install) tested++ } else if r.MTU == 9000 { assert.Equal(t, 9000, r.MTU) assert.Equal(t, "1.0.0.0/29", r.Cidr.String()) + assert.True(t, r.Install) tested++ } else { assert.Equal(t, 1500, r.MTU) assert.Equal(t, 1234, r.Metric) assert.Equal(t, "1.0.0.2/32", r.Cidr.String()) + assert.True(t, r.Install) tested++ } } - if tested != 3 { - t.Fatal("Did not see both unsafe_routes") + if tested != 4 { + t.Fatal("Did not see all unsafe_routes") } } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index d7b4884..6320570 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -287,7 +287,7 @@ func (t *tun) Activate() error { // Unsafe path routes for _, r := range t.Routes { - if r.Via == nil { + if r.Via == nil || !r.Install { // We don't allow route MTUs so only install routes with a via continue } diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 0a3f722..1054228 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -86,7 +86,7 @@ func (t *tun) Activate() error { } // Unsafe path routes for _, r := range t.Routes { - if r.Via == nil { + if r.Via == nil || !r.Install { // We don't allow route MTUs so only install routes with a via continue } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 1406438..932b585 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -279,6 +279,10 @@ func (t tun) Activate() error { // Path routes for _, r := range t.Routes { + if !r.Install { + continue + } + nr := netlink.Route{ LinkIndex: link.Attrs().Index, Dst: r.Cidr, diff --git a/overlay/tun_water_windows.go b/overlay/tun_water_windows.go index 8e2e571..b1c28d6 100644 --- a/overlay/tun_water_windows.go +++ b/overlay/tun_water_windows.go @@ -80,7 +80,7 @@ func (t *waterTun) Activate() error { } for _, r := range t.Routes { - if r.Via == nil { + if r.Via == nil || !r.Install { // We don't allow route MTUs so only install routes with a via continue } diff --git a/overlay/tun_wintun_windows.go b/overlay/tun_wintun_windows.go index 0538849..9146c88 100644 --- a/overlay/tun_wintun_windows.go +++ b/overlay/tun_wintun_windows.go @@ -92,7 +92,7 @@ func (t *winTun) Activate() error { routes := make([]*winipcfg.RouteData, 0, len(t.Routes)+1) for _, r := range t.Routes { - if r.Via == nil { + if r.Via == nil || !r.Install { // We don't allow route MTUs so only install routes with a via continue } diff --git a/udp/conn.go b/udp/conn.go index fa52fe5..f967a9a 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -9,7 +9,6 @@ const MTU = 9001 type EncReader func( addr *Addr, - via interface{}, out []byte, packet []byte, header *header.H, diff --git a/udp/temp.go b/udp/temp.go index 5cc8c1c..2efe31d 100644 --- a/udp/temp.go +++ b/udp/temp.go @@ -1,22 +1,9 @@ package udp import ( - "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" ) -type EncWriter interface { - SendVia(via interface{}, - relay interface{}, - ad, - nb, - out []byte, - nocopy bool, - ) - SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) - Handshake(vpnIp iputil.VpnIp) -} - //TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare -type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) +type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte) diff --git a/udp/udp_generic.go b/udp/udp_generic.go index 0a7c0d9..ff254eb 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -23,9 +23,9 @@ type Conn struct { l *logrus.Logger } -func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) (*Conn, error) { +func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (*Conn, error) { lc := NewListenConfig(multi) - pc, err := lc.ListenPacket(context.TODO(), "udp", fmt.Sprintf("%s:%d", ip, port)) + pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))) if err != nil { return nil, err } @@ -86,6 +86,6 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall udpAddr.IP = rua.IP udpAddr.Port = uint16(rua.Port) - r(udpAddr, nil, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l)) } } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 5d4b16a..26bbe36 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -45,7 +45,7 @@ const ( type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32 -func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) (*Conn, error) { +func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (*Conn, error) { syscall.ForkLock.RLock() fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP) if err == nil { @@ -59,7 +59,7 @@ func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) ( } var lip [16]byte - copy(lip[:], net.ParseIP(ip)) + copy(lip[:], ip.To16()) if multi { if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { @@ -145,7 +145,7 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall for i := 0; i < n; i++ { udpAddr.IP = names[i][8:24] udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4]) - r(udpAddr, nil, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l)) } } } diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 62e4f56..8b5e531 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -45,9 +45,9 @@ type Conn struct { l *logrus.Logger } -func NewListener(l *logrus.Logger, ip string, port int, _ bool, _ int) (*Conn, error) { +func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (*Conn, error) { return &Conn{ - Addr: &Addr{net.ParseIP(ip), uint16(port)}, + Addr: &Addr{ip, uint16(port)}, RxPackets: make(chan *Packet, 10), TxPackets: make(chan *Packet, 10), l: l, @@ -122,7 +122,7 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall } ua.Port = p.FromPort copy(ua.IP, p.FromIp.To16()) - r(ua, nil, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l)) } }