apic: minor refactoring (#2415)

* apic: minor refactoring

* Add whitelist length check

If user configures the file but fails to define and actual whitelist we should check length to save allocs

* Init with length from file

* extract loop method from ApplyApicWhitelists

* pass pointer

* extract loop method updateBlocklist

---------

Co-authored-by: Laurence Jones <laurence.jones@live.co.uk>
This commit is contained in:
mmetc 2023-08-10 13:03:47 +02:00 committed by GitHub
parent 93c22f29cf
commit afeb541eac
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 95 additions and 88 deletions

View file

@ -620,59 +620,57 @@ func (a *apic) PullTop(forcePull bool) error {
return nil
}
// if decisions is whitelisted: return representation of the whitelist ip or cidr
// if not whitelisted: empty string
func (a *apic) whitelistedBy(decision *models.Decision) string {
if decision.Value == nil {
return ""
}
ipval := net.ParseIP(*decision.Value)
for _, cidr := range a.whitelists.Cidrs {
if cidr.Contains(ipval) {
return cidr.String()
}
}
for _, ip := range a.whitelists.Ips {
if ip != nil && ip.Equal(ipval) {
return ip.String()
}
}
return ""
}
func (a *apic) ApplyApicWhitelists(decisions []*models.Decision) []*models.Decision {
if a.whitelists == nil {
if a.whitelists == nil || len(a.whitelists.Cidrs) == 0 && len(a.whitelists.Ips) == 0 {
return decisions
}
//deal with CAPI whitelists for fire. We want to avoid having a second list, so we shrink in place
outIdx := 0
for _, decision := range decisions {
if decision.Value == nil {
whitelister := a.whitelistedBy(decision)
if whitelister != "" {
log.Infof("%s from %s is whitelisted by %s", *decision.Value, *decision.Scenario, whitelister)
continue
}
skip := false
ipval := net.ParseIP(*decision.Value)
for _, cidr := range a.whitelists.Cidrs {
if skip {
break
}
if cidr.Contains(ipval) {
log.Infof("%s from %s is whitelisted by %s", *decision.Value, *decision.Scenario, cidr.String())
skip = true
}
}
for _, ip := range a.whitelists.Ips {
if skip {
break
}
if ip != nil && ip.Equal(ipval) {
log.Infof("%s from %s is whitelisted by %s", *decision.Value, *decision.Scenario, ip.String())
skip = true
}
}
if !skip {
decisions[outIdx] = decision
outIdx++
}
decisions[outIdx] = decision
outIdx++
}
//shrink the list, those are deleted items
decisions = decisions[:outIdx]
return decisions
return decisions[:outIdx]
}
func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, add_counters map[string]map[string]int, delete_counters map[string]map[string]int) error {
for idx, alert := range alertsFromCapi {
alertsFromCapi[idx] = setAlertScenario(add_counters, delete_counters, alert)
log.Debugf("%s has %d decisions", *alertsFromCapi[idx].Source.Scope, len(alertsFromCapi[idx].Decisions))
for _, alert := range alertsFromCapi {
setAlertScenario(alert, add_counters, delete_counters)
log.Debugf("%s has %d decisions", *alert.Source.Scope, len(alert.Decisions))
if a.dbClient.Type == "sqlite" && (a.dbClient.WalMode == nil || !*a.dbClient.WalMode) {
log.Warningf("sqlite is not using WAL mode, LAPI might become unresponsive when inserting the community blocklist")
}
alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(alertsFromCapi[idx])
alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(alert)
if err != nil {
return fmt.Errorf("while saving alert from %s: %w", *alertsFromCapi[idx].Source.Scope, err)
return fmt.Errorf("while saving alert from %s: %w", *alert.Source.Scope, err)
}
log.Printf("%s : added %d entries, deleted %d entries (alert:%d)", *alertsFromCapi[idx].Source.Scope, inserted, deleted, alertID)
log.Printf("%s : added %d entries, deleted %d entries (alert:%d)", *alert.Source.Scope, inserted, deleted, alertID)
}
return nil
@ -708,6 +706,60 @@ func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bo
return false, nil
}
func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, add_counters map[string]map[string]int) error {
if blocklist.Scope == nil {
log.Warningf("blocklist has no scope")
return nil
}
if blocklist.Duration == nil {
log.Warningf("blocklist has no duration")
return nil
}
forcePull, err := a.ShouldForcePullBlocklist(blocklist)
if err != nil {
return fmt.Errorf("while checking if we should force pull blocklist %s: %w", *blocklist.Name, err)
}
blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *blocklist.Name)
var lastPullTimestamp *string
if !forcePull {
lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName)
if err != nil {
return fmt.Errorf("while getting last pull timestamp for blocklist %s: %w", *blocklist.Name, err)
}
}
decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(context.Background(), blocklist, lastPullTimestamp)
if err != nil {
return fmt.Errorf("while getting decisions from blocklist %s: %w", *blocklist.Name, err)
}
if !hasChanged {
if lastPullTimestamp == nil {
log.Infof("blocklist %s hasn't been modified or there was an error reading it, skipping", *blocklist.Name)
} else {
log.Infof("blocklist %s hasn't been modified since %s, skipping", *blocklist.Name, *lastPullTimestamp)
}
return nil
}
err = a.dbClient.SetConfigItem(blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat))
if err != nil {
return fmt.Errorf("while setting last pull timestamp for blocklist %s: %w", *blocklist.Name, err)
}
if len(decisions) == 0 {
log.Infof("blocklist %s has no decisions", *blocklist.Name)
return nil
}
//apply APIC specific whitelists
decisions = a.ApplyApicWhitelists(decisions)
alert := createAlertForDecision(decisions[0])
alertsFromCapi := []*models.Alert{alert}
alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, add_counters)
err = a.SaveAlerts(alertsFromCapi, add_counters, nil)
if err != nil {
return fmt.Errorf("while saving alert from blocklist %s: %w", *blocklist.Name, err)
}
return nil
}
func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, add_counters map[string]map[string]int) error {
if links == nil {
return nil
@ -722,61 +774,14 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink
return fmt.Errorf("while creating default client: %w", err)
}
for _, blocklist := range links.Blocklists {
if blocklist.Scope == nil {
log.Warningf("blocklist has no scope")
continue
}
if blocklist.Duration == nil {
log.Warningf("blocklist has no duration")
continue
}
forcePull, err := a.ShouldForcePullBlocklist(blocklist)
if err != nil {
return fmt.Errorf("while checking if we should force pull blocklist %s: %w", *blocklist.Name, err)
}
blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *blocklist.Name)
var lastPullTimestamp *string
if !forcePull {
lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName)
if err != nil {
return fmt.Errorf("while getting last pull timestamp for blocklist %s: %w", *blocklist.Name, err)
}
}
decisions, has_changed, err := defaultClient.Decisions.GetDecisionsFromBlocklist(context.Background(), blocklist, lastPullTimestamp)
if err != nil {
return fmt.Errorf("while getting decisions from blocklist %s: %w", *blocklist.Name, err)
}
if !has_changed {
if lastPullTimestamp == nil {
log.Infof("blocklist %s hasn't been modified or there was an error reading it, skipping", *blocklist.Name)
} else {
log.Infof("blocklist %s hasn't been modified since %s, skipping", *blocklist.Name, *lastPullTimestamp)
}
continue
}
err = a.dbClient.SetConfigItem(blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat))
if err != nil {
return fmt.Errorf("while setting last pull timestamp for blocklist %s: %w", *blocklist.Name, err)
}
if len(decisions) == 0 {
log.Infof("blocklist %s has no decisions", *blocklist.Name)
continue
}
//apply APIC specific whitelists
decisions = a.ApplyApicWhitelists(decisions)
alert := createAlertForDecision(decisions[0])
alertsFromCapi := []*models.Alert{alert}
alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, add_counters)
err = a.SaveAlerts(alertsFromCapi, add_counters, nil)
if err != nil {
return fmt.Errorf("while saving alert from blocklist %s: %w", *blocklist.Name, err)
}
if err := a.updateBlocklist(defaultClient, blocklist, add_counters); err != nil {
return err
}
}
return nil
}
func setAlertScenario(add_counters map[string]map[string]int, delete_counters map[string]map[string]int, alert *models.Alert) *models.Alert {
func setAlertScenario(alert *models.Alert, add_counters map[string]map[string]int, delete_counters map[string]map[string]int) {
if *alert.Source.Scope == types.CAPIOrigin {
*alert.Source.Scope = SCOPE_CAPI_ALIAS_ALIAS
alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", add_counters[types.CAPIOrigin]["all"], delete_counters[types.CAPIOrigin]["all"]))
@ -784,7 +789,6 @@ func setAlertScenario(add_counters map[string]map[string]int, delete_counters ma
*alert.Source.Scope = fmt.Sprintf("%s:%s", types.ListOrigin, *alert.Scenario)
alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", add_counters[types.ListOrigin][*alert.Scenario], delete_counters[types.ListOrigin][*alert.Scenario]))
}
return alert
}
func (a *apic) Pull() error {

View file

@ -344,13 +344,16 @@ func (s *LocalApiServerCfg) LoadCapiWhitelists() error {
}
var fromCfg capiWhitelists
s.CapiWhitelists = &CapiWhitelist{}
defer fd.Close()
decoder := yaml.NewDecoder(fd)
if err := decoder.Decode(&fromCfg); err != nil {
return fmt.Errorf("while parsing capi whitelist file '%s': %s", s.CapiWhitelistsPath, err)
}
s.CapiWhitelists = &CapiWhitelist{
Ips: make([]net.IP, len(fromCfg.Ips)),
Cidrs: make([]*net.IPNet, len(fromCfg.Cidrs)),
}
for _, v := range fromCfg.Ips {
ip := net.ParseIP(v)
if ip == nil {