Add destination CIDR checking (#507)

This commit is contained in:
Ilya Lukyanov
2023-05-09 16:37:23 +01:00
committed by GitHub
parent a9cb2e06f4
commit 1701087035
3 changed files with 169 additions and 82 deletions

View File

@@ -25,7 +25,7 @@ const tcpACK = 0x10
const tcpFIN = 0x01
type FirewallInterface interface {
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error
}
type conn struct {
@@ -106,11 +106,12 @@ type FirewallCA struct {
}
type FirewallRule struct {
// Any makes Hosts, Groups, and CIDR irrelevant
Any bool
Hosts map[string]struct{}
Groups [][]string
CIDR *cidr.Tree4
// Any makes Hosts, Groups, CIDR and LocalCIDR irrelevant
Any bool
Hosts map[string]struct{}
Groups [][]string
CIDR *cidr.Tree4
LocalCIDR *cidr.Tree4
}
// Even though ports are uint16, int32 maps are faster for lookup
@@ -218,18 +219,22 @@ func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *conf
}
// AddRule properly creates the in memory rule structure for a firewall table.
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
// https://github.com/golang/go/issues/14131
sIp := ""
if ip != nil {
sIp = ip.String()
}
lIp := ""
if localIp != nil {
lIp = localIp.String()
}
// We need this rule string because we generate a hash. Removing this will break firewall reload.
ruleString := fmt.Sprintf(
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s",
incoming, proto, startPort, endPort, groups, host, sIp, caName, caSha,
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
incoming, proto, startPort, endPort, groups, host, sIp, lIp, caName, caSha,
)
f.rules += ruleString + "\n"
@@ -237,7 +242,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
if !incoming {
direction = "outgoing"
}
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "localIp": lIp, "caName": caName, "caSha": caSha}).
Info("Firewall rule added")
var (
@@ -264,7 +269,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
return fmt.Errorf("unknown protocol %v", proto)
}
return fp.addRule(startPort, endPort, groups, host, ip, caName, caSha)
return fp.addRule(startPort, endPort, groups, host, ip, localIp, caName, caSha)
}
// GetRuleHash returns a hash representation of all inbound and outbound rules
@@ -302,8 +307,8 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i)
}
if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.CAName == "" && r.CASha == "" {
return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, ca_name, or ca_sha must be provided", table, i)
if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" {
return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i)
}
if len(r.Groups) > 0 {
@@ -355,7 +360,15 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
}
}
err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, r.CAName, r.CASha)
var localCidr *net.IPNet
if r.LocalCidr != "" {
_, localCidr, err = net.ParseCIDR(r.LocalCidr)
if err != nil {
return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
}
}
err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, localCidr, r.CAName, r.CASha)
if err != nil {
return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
}
@@ -595,7 +608,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC
return false
}
func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
if startPort > endPort {
return fmt.Errorf("start port was lower than end port")
}
@@ -608,7 +621,7 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
}
}
if err := fp[i].addRule(groups, host, ip, caName, caSha); err != nil {
if err := fp[i].addRule(groups, host, ip, localIp, caName, caSha); err != nil {
return err
}
}
@@ -639,12 +652,13 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer
return fp[firewall.PortAny].match(p, c, caPool)
}
func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error {
func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
fr := func() *FirewallRule {
return &FirewallRule{
Hosts: make(map[string]struct{}),
Groups: make([][]string, 0),
CIDR: cidr.NewTree4(),
Hosts: make(map[string]struct{}),
Groups: make([][]string, 0),
CIDR: cidr.NewTree4(),
LocalCIDR: cidr.NewTree4(),
}
}
@@ -653,14 +667,14 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
fc.Any = fr()
}
return fc.Any.addRule(groups, host, ip)
return fc.Any.addRule(groups, host, ip, localIp)
}
if caSha != "" {
if _, ok := fc.CAShas[caSha]; !ok {
fc.CAShas[caSha] = fr()
}
err := fc.CAShas[caSha].addRule(groups, host, ip)
err := fc.CAShas[caSha].addRule(groups, host, ip, localIp)
if err != nil {
return err
}
@@ -670,7 +684,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
if _, ok := fc.CANames[caName]; !ok {
fc.CANames[caName] = fr()
}
err := fc.CANames[caName].addRule(groups, host, ip)
err := fc.CANames[caName].addRule(groups, host, ip, localIp)
if err != nil {
return err
}
@@ -702,17 +716,18 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
return fc.CANames[s.Details.Name].match(p, c)
}
func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error {
func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, localIp *net.IPNet) error {
if fr.Any {
return nil
}
if fr.isAny(groups, host, ip) {
if fr.isAny(groups, host, ip, localIp) {
fr.Any = true
// If it's any we need to wipe out any pre-existing rules to save on memory
fr.Groups = make([][]string, 0)
fr.Hosts = make(map[string]struct{})
fr.CIDR = cidr.NewTree4()
fr.LocalCIDR = cidr.NewTree4()
} else {
if len(groups) > 0 {
fr.Groups = append(fr.Groups, groups)
@@ -725,13 +740,17 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) err
if ip != nil {
fr.CIDR.AddCIDR(ip, struct{}{})
}
if localIp != nil {
fr.LocalCIDR.AddCIDR(localIp, struct{}{})
}
}
return nil
}
func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
if len(groups) == 0 && host == "" && ip == nil {
func (fr *FirewallRule) isAny(groups []string, host string, ip, localIp *net.IPNet) bool {
if len(groups) == 0 && host == "" && ip == nil && localIp == nil {
return true
}
@@ -749,6 +768,10 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
return true
}
if localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0)) {
return true
}
return false
}
@@ -790,20 +813,25 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
return true
}
if fr.LocalCIDR != nil && fr.LocalCIDR.Contains(p.LocalIP) != nil {
return true
}
// No host, group, or cidr matched, bye bye
return false
}
type rule struct {
Port string
Code string
Proto string
Host string
Group string
Groups []string
Cidr string
CAName string
CASha string
Port string
Code string
Proto string
Host string
Group string
Groups []string
Cidr string
LocalCidr string
CAName string
CASha string
}
func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) {
@@ -827,6 +855,7 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er
r.Proto = toString("proto", m)
r.Host = toString("host", m)
r.Cidr = toString("cidr", m)
r.LocalCidr = toString("local_cidr", m)
r.CAName = toString("ca_name", m)
r.CASha = toString("ca_sha", m)