mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 08:24:25 +01:00
V2 certificate format (#1216)
Co-authored-by: Nate Brown <nbrown.us@gmail.com> Co-authored-by: Jack Doan <jackdoan@rivian.com> Co-authored-by: brad-defined <77982333+brad-defined@users.noreply.github.com> Co-authored-by: Jack Doan <me@jackdoan.com>
This commit is contained in:
110
dns_server.go
110
dns_server.go
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
@@ -21,24 +22,39 @@ var dnsAddr string
|
||||
|
||||
type dnsRecords struct {
|
||||
sync.RWMutex
|
||||
dnsMap map[string]string
|
||||
hostMap *HostMap
|
||||
l *logrus.Logger
|
||||
dnsMap4 map[string]netip.Addr
|
||||
dnsMap6 map[string]netip.Addr
|
||||
hostMap *HostMap
|
||||
myVpnAddrsTable *bart.Table[struct{}]
|
||||
}
|
||||
|
||||
func newDnsRecords(hostMap *HostMap) *dnsRecords {
|
||||
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
|
||||
return &dnsRecords{
|
||||
dnsMap: make(map[string]string),
|
||||
hostMap: hostMap,
|
||||
l: l,
|
||||
dnsMap4: make(map[string]netip.Addr),
|
||||
dnsMap6: make(map[string]netip.Addr),
|
||||
hostMap: hostMap,
|
||||
myVpnAddrsTable: cs.myVpnAddrsTable,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dnsRecords) Query(data string) string {
|
||||
func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
|
||||
data = strings.ToLower(data)
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
if r, ok := d.dnsMap[strings.ToLower(data)]; ok {
|
||||
return r
|
||||
switch q {
|
||||
case dns.TypeA:
|
||||
if r, ok := d.dnsMap4[data]; ok {
|
||||
return r
|
||||
}
|
||||
case dns.TypeAAAA:
|
||||
if r, ok := d.dnsMap6[data]; ok {
|
||||
return r
|
||||
}
|
||||
}
|
||||
return ""
|
||||
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
func (d *dnsRecords) QueryCert(data string) string {
|
||||
@@ -47,7 +63,7 @@ func (d *dnsRecords) QueryCert(data string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
hostinfo := d.hostMap.QueryVpnIp(ip)
|
||||
hostinfo := d.hostMap.QueryVpnAddr(ip)
|
||||
if hostinfo == nil {
|
||||
return ""
|
||||
}
|
||||
@@ -64,38 +80,62 @@ func (d *dnsRecords) QueryCert(data string) string {
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func (d *dnsRecords) Add(host, data string) {
|
||||
// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
|
||||
func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
|
||||
host = strings.ToLower(host)
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
d.dnsMap[strings.ToLower(host)] = data
|
||||
haveV4 := false
|
||||
haveV6 := false
|
||||
for _, addr := range addresses {
|
||||
if addr.Is4() && !haveV4 {
|
||||
d.dnsMap4[host] = addr
|
||||
haveV4 = true
|
||||
} else if addr.Is6() && !haveV6 {
|
||||
d.dnsMap6[host] = addr
|
||||
haveV6 = true
|
||||
}
|
||||
if haveV4 && haveV6 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
|
||||
func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
|
||||
a, _, _ := net.SplitHostPort(addr)
|
||||
b, err := netip.ParseAddr(a)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if b.IsLoopback() {
|
||||
return true
|
||||
}
|
||||
|
||||
_, found := d.myVpnAddrsTable.Lookup(b)
|
||||
return found //if we found it in this table, it's good
|
||||
}
|
||||
|
||||
func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
||||
for _, q := range m.Question {
|
||||
switch q.Qtype {
|
||||
case dns.TypeA:
|
||||
l.Debugf("Query for A %s", q.Name)
|
||||
ip := dnsR.Query(q.Name)
|
||||
if ip != "" {
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip))
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
qType := dns.TypeToString[q.Qtype]
|
||||
d.l.Debugf("Query for %s %s", qType, q.Name)
|
||||
ip := d.Query(q.Qtype, q.Name)
|
||||
if ip.IsValid() {
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
|
||||
if err == nil {
|
||||
m.Answer = append(m.Answer, rr)
|
||||
}
|
||||
}
|
||||
case dns.TypeTXT:
|
||||
a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
|
||||
b, err := netip.ParseAddr(a)
|
||||
if err != nil {
|
||||
// We only answer these queries from nebula nodes or localhost
|
||||
if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
|
||||
return
|
||||
}
|
||||
|
||||
// We don't answer these queries from non nebula nodes or localhost
|
||||
//l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
|
||||
if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
|
||||
return
|
||||
}
|
||||
l.Debugf("Query for TXT %s", q.Name)
|
||||
ip := dnsR.QueryCert(q.Name)
|
||||
d.l.Debugf("Query for TXT %s", q.Name)
|
||||
ip := d.QueryCert(q.Name)
|
||||
if ip != "" {
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
|
||||
if err == nil {
|
||||
@@ -110,26 +150,24 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
|
||||
}
|
||||
}
|
||||
|
||||
func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
|
||||
func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Compress = false
|
||||
|
||||
switch r.Opcode {
|
||||
case dns.OpcodeQuery:
|
||||
parseQuery(l, m, w)
|
||||
d.parseQuery(m, w)
|
||||
}
|
||||
|
||||
w.WriteMsg(m)
|
||||
}
|
||||
|
||||
func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() {
|
||||
dnsR = newDnsRecords(hostMap)
|
||||
func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
|
||||
dnsR = newDnsRecords(l, cs, hostMap)
|
||||
|
||||
// attach request handler func
|
||||
dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
handleDnsRequest(l, w, r)
|
||||
})
|
||||
dns.HandleFunc(".", dnsR.handleDnsRequest)
|
||||
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
reloadDns(l, c)
|
||||
|
||||
Reference in New Issue
Block a user