add tests for pkg/database (#151)

This commit is contained in:
Thibault "bui" Koechlin 2020-07-30 15:58:06 +02:00 committed by GitHub
parent daf2a350ea
commit 0356f8404b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 1137 additions and 121 deletions

View file

@ -3,125 +3,14 @@ package database
import (
"time"
"github.com/crowdsecurity/crowdsec/pkg/types"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
)
func (c *Context) DeleteExpired() error {
c.lock.Lock()
defer c.lock.Unlock()
//Delete the expired records
now := time.Now()
if c.flush {
retx := c.Db.Delete(types.BanApplication{}, "until < ?", now)
if retx.RowsAffected > 0 {
log.Infof("Flushed %d expired entries from Ban Application", retx.RowsAffected)
}
}
return nil
}
/*Flush doesn't do anything here : we are not using transactions or such, nothing to "flush" per se*/
func (c *Context) Flush() error {
return nil
}
func (c *Context) CleanUpRecordsByAge() error {
//let's fetch all expired records that are more than XX days olds
sos := []types.BanApplication{}
if c.maxDurationRetention == 0 {
return nil
}
//look for soft-deleted events that are OLDER than maxDurationRetention
ret := c.Db.Unscoped().Table("ban_applications").Where("deleted_at is not NULL").
Where("deleted_at < ?", time.Now().Add(-c.maxDurationRetention)).
Order("updated_at desc").Find(&sos)
if ret.Error != nil {
return errors.Wrap(ret.Error, "failed to get count of old records")
}
//no events elligible
if len(sos) == 0 || ret.RowsAffected == 0 {
log.Debugf("no event older than %s", c.maxDurationRetention.String())
return nil
}
delRecords := 0
for _, record := range sos {
copy := record
if ret := c.Db.Unscoped().Table("signal_occurences").Where("ID = ?", copy.SignalOccurenceID).Delete(&types.SignalOccurence{}); ret.Error != nil {
return errors.Wrap(ret.Error, "failed to clean signal_occurences")
}
if ret := c.Db.Unscoped().Table("event_sequences").Where("signal_occurence_id = ?", copy.SignalOccurenceID).Delete(&types.EventSequence{}); ret.Error != nil {
return errors.Wrap(ret.Error, "failed to clean event_sequences")
}
if ret := c.Db.Unscoped().Table("ban_applications").Delete(&copy); ret.Error != nil {
return errors.Wrap(ret.Error, "failed to clean ban_applications")
}
delRecords++
}
log.Printf("max_records_age: deleting %d events (max age:%s)", delRecords, c.maxDurationRetention)
return nil
}
func (c *Context) CleanUpRecordsByCount() error {
var count int
if c.maxEventRetention <= 0 {
return nil
}
ret := c.Db.Unscoped().Table("ban_applications").Order("updated_at desc").Count(&count)
if ret.Error != nil {
return errors.Wrap(ret.Error, "failed to get bans count")
}
if count < c.maxEventRetention {
log.Debugf("%d < %d, don't cleanup", count, c.maxEventRetention)
return nil
}
sos := []types.BanApplication{}
now := time.Now()
/*get soft deleted records oldest to youngest*/
//records := c.Db.Unscoped().Table("ban_applications").Where("deleted_at is not NULL").Where(`strftime("%s", deleted_at) < strftime("%s", "now")`).Find(&sos)
records := c.Db.Unscoped().Table("ban_applications").Where("deleted_at is not NULL").Where("deleted_at < ?", now).Find(&sos)
if records.Error != nil {
return errors.Wrap(records.Error, "failed to list expired bans for flush")
}
//let's do it in a single transaction
delRecords := 0
for _, ld := range sos {
copy := ld
if ret := c.Db.Unscoped().Table("signal_occurences").Where("ID = ?", copy.SignalOccurenceID).Delete(&types.SignalOccurence{}); ret.Error != nil {
return errors.Wrap(ret.Error, "failed to clean signal_occurences")
}
if ret := c.Db.Unscoped().Table("event_sequences").Where("signal_occurence_id = ?", copy.SignalOccurenceID).Delete(&types.EventSequence{}); ret.Error != nil {
return errors.Wrap(ret.Error, "failed to clean event_sequences")
}
if ret := c.Db.Unscoped().Table("ban_applications").Delete(&copy); ret.Error != nil {
return errors.Wrap(ret.Error, "failed to clean ban_applications")
}
//we need to delete associations : event_sequences, signal_occurences
delRecords++
//let's delete as well the associated event_sequence
if count-delRecords <= c.maxEventRetention {
break
}
}
if len(sos) > 0 {
log.Printf("max_records: deleting %d events. (%d soft-deleted)", delRecords, len(sos))
} else {
log.Debugf("didn't find any record to clean")
}
return nil
}
func (c *Context) StartAutoCommit() error {
//TBD : we shouldn't start auto-commit if we are in cli mode ?
c.PusherTomb.Go(func() error {
@ -151,14 +40,14 @@ func (c *Context) autoCommit() {
}
return
case <-expireTicker.C:
if err := c.DeleteExpired(); err != nil {
if _, err := c.DeleteExpired(); err != nil {
log.Errorf("Error while deleting expired records: %s", err)
}
case <-cleanUpTicker.C:
if err := c.CleanUpRecordsByCount(); err != nil {
if _, err := c.CleanUpRecordsByCount(); err != nil {
log.Errorf("error in max records cleanup : %s", err)
}
if err := c.CleanUpRecordsByAge(); err != nil {
if _, err := c.CleanUpRecordsByAge(); err != nil {
log.Errorf("error in old records cleanup : %s", err)
}

View file

@ -29,7 +29,7 @@ type Context struct {
}
func checkConfig(cfg map[string]string) error {
switch dbType, _ := cfg["type"]; dbType {
switch dbType := cfg["type"]; dbType {
case "sqlite":
if val, ok := cfg["db_path"]; !ok || val == "" {
return fmt.Errorf("please specify a 'db_path' to SQLite db in the configuration")

View file

@ -0,0 +1,157 @@
package database
import (
"net"
"testing"
"time"
"github.com/crowdsecurity/crowdsec/pkg/types"
)
func genSignalOccurence(ip string) types.SignalOccurence {
target_ip := net.ParseIP(ip)
Ban := types.BanApplication{
MeasureType: "ban",
MeasureSource: "local",
//for 10 minutes
Until: time.Now().Add(10 * time.Minute),
StartIp: types.IP2Int(target_ip),
EndIp: types.IP2Int(target_ip),
TargetCN: "FR",
TargetAS: 1234,
TargetASName: "Random AS",
IpText: target_ip.String(),
Reason: "A reason",
Scenario: "A scenario",
}
Signal := types.SignalOccurence{
MapKey: "lala",
Scenario: "old_overflow",
//a few minutes ago
Start_at: time.Now().Add(-10 * time.Minute),
Stop_at: time.Now().Add(-5 * time.Minute),
BanApplications: []types.BanApplication{Ban},
}
return Signal
}
func TestCreateDB(t *testing.T) {
var CfgTests = []struct {
cfg map[string]string
valid bool
}{
{map[string]string{
"type": "sqlite",
"db_path": "./test.db",
"max_records": "1000",
"max_records_age": "72h",
"debug": "false",
"flush": "true",
}, true},
//bad type
{map[string]string{
"type": "inexistant_DB",
"db_path": "./test.db",
"max_records": "1000",
"debug": "false",
"flush": "true",
}, false},
//missing db_path
{map[string]string{
"type": "sqlite",
"max_records": "1000",
"debug": "false",
"flush": "true",
}, false},
//valid mysql, but won't be able to connect and thus fail
{map[string]string{
"type": "mysql",
"db_host": "localhost",
"db_username": "crowdsec",
"db_password": "password",
"db_name": "crowdsec",
"max_records": "1000",
"debug": "false",
"flush": "true",
}, false},
//mysql : missing host
{map[string]string{
"type": "mysql",
"db_username": "crowdsec",
"db_password": "password",
"db_name": "crowdsec",
"max_records": "1000",
"debug": "false",
"flush": "true",
}, false},
//mysql : missing username
{map[string]string{
"type": "mysql",
"db_host": "localhost",
"db_password": "password",
"db_name": "crowdsec",
"max_records": "1000",
"debug": "false",
"flush": "true",
}, false},
//mysql : missing password
{map[string]string{
"type": "mysql",
"db_host": "localhost",
"db_username": "crowdsec",
"db_name": "crowdsec",
"max_records": "1000",
"debug": "false",
"flush": "true",
}, false},
//mysql : missing db_name
{map[string]string{
"type": "mysql",
"db_host": "localhost",
"db_username": "crowdsec",
"db_password": "password",
"max_records": "1000",
"debug": "false",
"flush": "true",
}, false},
//sqlite : bad bools
{map[string]string{
"type": "sqlite",
"db_path": "./test.db",
"max_records": "1000",
"max_records_age": "72h",
"debug": "false",
"flush": "ratata",
}, false},
}
for idx, TestCase := range CfgTests {
ctx, err := NewDatabase(TestCase.cfg)
if TestCase.valid {
if err != nil {
t.Fatalf("didn't expect error (case %d/%d) : %s", idx, len(CfgTests), err)
}
if ctx == nil {
t.Fatalf("didn't expect empty ctx (case %d/%d)", idx, len(CfgTests))
}
} else {
if err == nil {
t.Fatalf("expected error (case %d/%d)", idx, len(CfgTests))
}
if ctx != nil {
t.Fatalf("expected nil ctx (case %d/%d)", idx, len(CfgTests))
}
}
}
}

View file

@ -2,6 +2,9 @@ package database
import (
"fmt"
"time"
"github.com/pkg/errors"
"github.com/crowdsecurity/crowdsec/pkg/types"
log "github.com/sirupsen/logrus"
@ -23,9 +26,124 @@ func (c *Context) DeleteBan(target string) (int, error) {
func (c *Context) DeleteAll() error {
allBa := types.BanApplication{}
records := c.Db.Delete(&allBa)
records := c.Db.Unscoped().Delete(&allBa)
if records.Error != nil {
return records.Error
}
return nil
}
func (c *Context) DeleteExpired() (int, error) {
c.lock.Lock()
defer c.lock.Unlock()
//Delete the expired records
now := time.Now()
count := 0
if c.flush {
retx := c.Db.Delete(types.BanApplication{}, "until < ?", now)
if retx.Error != nil {
return 0, retx.Error
}
if retx.RowsAffected > 0 {
log.Infof("Flushed %d expired entries from Ban Application", retx.RowsAffected)
count = int(retx.RowsAffected)
}
}
return count, nil
}
func (c *Context) CleanUpRecordsByAge() (int, error) {
//let's fetch all expired records that are more than XX days olds
sos := []types.BanApplication{}
if c.maxDurationRetention == 0 {
return 0, nil
}
//look for soft-deleted events that are OLDER than maxDurationRetention
ret := c.Db.Unscoped().Table("ban_applications").Where("deleted_at is not NULL").
Where("until < ?", time.Now().Add(-c.maxDurationRetention)).
Order("updated_at desc").Find(&sos)
if ret.Error != nil {
return 0, errors.Wrap(ret.Error, "failed to get count of old records")
}
//no events elligible
if len(sos) == 0 || ret.RowsAffected == 0 {
log.Debugf("no event older than %s", c.maxDurationRetention.String())
return 0, nil
}
/*This is clearly suboptimal, and 'left join' and stuff gives way better results, but doesn't seem to behave equally on sqlite and mysql*/
delRecords := 0
for _, record := range sos {
copy := record
if ret := c.Db.Unscoped().Table("signal_occurences").Where("ID = ?", copy.SignalOccurenceID).Delete(&types.SignalOccurence{}); ret.Error != nil {
return 0, errors.Wrap(ret.Error, "failed to clean signal_occurences")
}
if ret := c.Db.Unscoped().Table("event_sequences").Where("signal_occurence_id = ?", copy.SignalOccurenceID).Delete(&types.EventSequence{}); ret.Error != nil {
return 0, errors.Wrap(ret.Error, "failed to clean event_sequences")
}
if ret := c.Db.Unscoped().Table("ban_applications").Delete(&copy); ret.Error != nil {
return 0, errors.Wrap(ret.Error, "failed to clean ban_applications")
}
delRecords++
}
log.Printf("max_records_age: deleting %d events (max age:%s)", delRecords, c.maxDurationRetention)
return delRecords, nil
}
func (c *Context) CleanUpRecordsByCount() (int, error) {
var count int
if c.maxEventRetention <= 0 {
return 0, nil
}
ret := c.Db.Unscoped().Table("ban_applications").Count(&count)
if ret.Error != nil {
return 0, errors.Wrap(ret.Error, "failed to get bans count")
}
if count < c.maxEventRetention {
log.Debugf("%d < %d, don't cleanup", count, c.maxEventRetention)
return 0, nil
}
sos := []types.BanApplication{}
now := time.Now()
/*get soft deleted records oldest to youngest*/
//records := c.Db.Unscoped().Table("ban_applications").Where("deleted_at is not NULL").Where(`strftime("%s", deleted_at) < strftime("%s", "now")`).Find(&sos)
records := c.Db.Unscoped().Table("ban_applications").Where("deleted_at is not NULL").Where("deleted_at < ?", now).Find(&sos)
if records.Error != nil {
return 0, errors.Wrap(records.Error, "failed to list expired bans for flush")
}
//let's do it in a single transaction
delRecords := 0
for _, ld := range sos {
copy := ld
if ret := c.Db.Unscoped().Table("signal_occurences").Where("ID = ?", copy.SignalOccurenceID).Delete(&types.SignalOccurence{}); ret.Error != nil {
return 0, errors.Wrap(ret.Error, "failed to clean signal_occurences")
}
if ret := c.Db.Unscoped().Table("event_sequences").Where("signal_occurence_id = ?", copy.SignalOccurenceID).Delete(&types.EventSequence{}); ret.Error != nil {
return 0, errors.Wrap(ret.Error, "failed to clean event_sequences")
}
if ret := c.Db.Unscoped().Table("ban_applications").Delete(&copy); ret.Error != nil {
return 0, errors.Wrap(ret.Error, "failed to clean ban_applications")
}
//we need to delete associations : event_sequences, signal_occurences
delRecords++
//let's delete as well the associated event_sequence
if count-delRecords <= c.maxEventRetention {
break
}
}
if len(sos) > 0 {
log.Printf("max_records: deleting %d events. (%d soft-deleted)", delRecords, len(sos))
} else {
log.Debugf("didn't find any record to clean")
}
return delRecords, nil
}

346
pkg/database/delete_test.go Normal file
View file

@ -0,0 +1,346 @@
package database
import (
"fmt"
"testing"
"time"
)
func TestNoCleanUpParams(t *testing.T) {
validCfg := map[string]string{
"type": "sqlite",
"db_path": "./test.db",
"debug": "false",
"max_records": "0",
"max_records_age": "0s",
"flush": "true",
}
ctx, err := NewDatabase(validCfg)
if err != nil || ctx == nil {
t.Fatalf("failed to create simple sqlite")
}
if err := ctx.DeleteAll(); err != nil {
t.Fatalf("failed to flush existing bans")
}
freshRecordsCount := 12
for i := 0; i < freshRecordsCount; i++ {
//this one expires in the future
OldSignal := genSignalOccurence(fmt.Sprintf("2.2.2.%d", i))
OldSignal.BanApplications[0].Until = time.Now().Add(1 * time.Hour)
if err = ctx.WriteBanApplication(OldSignal.BanApplications[0]); err != nil {
t.Fatalf("Failed to insert old signal : %s", err)
}
}
bans, err := ctx.GetBansAt(time.Now())
if err != nil {
t.Fatalf("%s", err)
}
if len(bans) != freshRecordsCount {
t.Fatalf("expected %d, got %d", freshRecordsCount, len(bans))
}
//Cleanup by age should hard delete old records
deleted, err := ctx.CleanUpRecordsByCount()
if err != nil {
t.Fatalf("error %s", err)
}
if deleted != 0 {
t.Fatalf("unexpected %d deleted events", deleted)
}
//Cleanup by age should hard delete old records
deleted, err = ctx.CleanUpRecordsByAge()
if err != nil {
t.Fatalf("error %s", err)
}
if deleted != 0 {
t.Fatalf("unexpected %d deleted events ", deleted)
}
}
func TestNoCleanUp(t *testing.T) {
validCfg := map[string]string{
"type": "sqlite",
"db_path": "./test.db",
"debug": "false",
"max_records": "1000",
"max_records_age": "24h",
"flush": "true",
}
ctx, err := NewDatabase(validCfg)
if err != nil || ctx == nil {
t.Fatalf("failed to create simple sqlite")
}
if err := ctx.DeleteAll(); err != nil {
t.Fatalf("failed to flush existing bans")
}
freshRecordsCount := 12
for i := 0; i < freshRecordsCount; i++ {
//this one expires in the future
OldSignal := genSignalOccurence(fmt.Sprintf("2.2.2.%d", i))
OldSignal.BanApplications[0].Until = time.Now().Add(1 * time.Hour)
if err = ctx.WriteBanApplication(OldSignal.BanApplications[0]); err != nil {
t.Fatalf("Failed to insert old signal : %s", err)
}
}
bans, err := ctx.GetBansAt(time.Now())
if err != nil {
t.Fatalf("%s", err)
}
if len(bans) != freshRecordsCount {
t.Fatalf("expected %d, got %d", freshRecordsCount, len(bans))
}
//Cleanup by age should hard delete old records
deleted, err := ctx.CleanUpRecordsByCount()
if err != nil {
t.Fatalf("error %s", err)
}
if deleted != 0 {
t.Fatalf("unexpected %d deleted events", deleted)
}
//Cleanup by age should hard delete old records
deleted, err = ctx.CleanUpRecordsByAge()
if err != nil {
t.Fatalf("error %s", err)
}
if deleted != 0 {
t.Fatalf("unexpected %d deleted events ", deleted)
}
}
func TestCleanUpByCount(t *testing.T) {
//plan :
// - insert one current event
// - insert 150 old events
// - check DeletedExpired behavior
// - check CleanUpByCount behavior
maxCount := 72
validCfg := map[string]string{
"type": "sqlite",
"db_path": "./test.db",
//that's 15 days
"max_records": fmt.Sprintf("%d", maxCount),
"debug": "false",
"flush": "true",
}
ctx, err := NewDatabase(validCfg)
if err != nil || ctx == nil {
t.Fatalf("failed to create simple sqlite")
}
if err := ctx.DeleteAll(); err != nil {
t.Fatalf("failed to flush existing bans")
}
freshRecordsCount := 12
for i := 0; i < freshRecordsCount; i++ {
//this one expires in the future
OldSignal := genSignalOccurence(fmt.Sprintf("2.2.2.%d", i))
OldSignal.BanApplications[0].Until = time.Now().Add(1 * time.Hour)
if err = ctx.WriteSignal(OldSignal); err != nil {
t.Fatalf("Failed to insert old signal : %s", err)
}
}
oldRecordsCount := 136
for i := 0; i < oldRecordsCount; i++ {
OldSignal := genSignalOccurence(fmt.Sprintf("1.2.3.%d", i))
//let's make the event a month old
OldSignal.Start_at = time.Now().Add(-30 * 24 * time.Hour)
OldSignal.Stop_at = time.Now().Add(-30 * 24 * time.Hour)
//ban was like for an hour
OldSignal.BanApplications[0].Until = time.Now().Add(-30*24*time.Hour + 1*time.Hour)
//write the old signal
if err = ctx.WriteSignal(OldSignal); err != nil {
t.Fatalf("Failed to insert old signal : %s", err)
}
}
evtsCount := 0
ret := ctx.Db.Unscoped().Table("ban_applications").Count(&evtsCount)
if ret.Error != nil {
t.Fatalf("got err : %s", ret.Error)
}
if evtsCount != oldRecordsCount+freshRecordsCount {
t.Fatalf("got %d events", evtsCount)
}
//if we call DeleteExpired, it will soft deleted those events in the past
softDeleted, err := ctx.DeleteExpired()
if err != nil {
t.Fatalf("%s", err)
}
if softDeleted != oldRecordsCount {
t.Fatalf("%d deleted records", softDeleted)
}
//we should be left with *one* non-deleted record
evtsCount = 0
ret = ctx.Db.Table("ban_applications").Where("deleted_at is NULL").Count(&evtsCount)
if ret.Error != nil {
t.Fatalf("got err : %s", ret.Error)
}
if evtsCount != freshRecordsCount {
t.Fatalf("got %d events", evtsCount)
}
evtsCount = 0
ret = ctx.Db.Table("ban_applications").Where("deleted_at is not NULL").Count(&evtsCount)
if ret.Error != nil {
t.Fatalf("got err : %s", ret.Error)
}
if evtsCount != oldRecordsCount {
t.Fatalf("got %d events", evtsCount)
}
//ctx.Db.LogMode(true)
//Cleanup by age should hard delete old records
deleted, err := ctx.CleanUpRecordsByCount()
if err != nil {
t.Fatalf("error %s", err)
}
if deleted != (oldRecordsCount+freshRecordsCount)-maxCount {
t.Fatalf("unexpected %d deleted events (expected: %d)", deleted, oldRecordsCount-maxCount)
}
//and now we should have *one* record left !
evtsCount = 0
ret = ctx.Db.Unscoped().Table("ban_applications").Count(&evtsCount)
if ret.Error != nil {
t.Fatalf("got err : %s", ret.Error)
}
if evtsCount != maxCount {
t.Fatalf("got %d events", evtsCount)
}
}
func TestCleanUpByAge(t *testing.T) {
//plan :
// - insert one current event
// - insert 150 old events
// - check DeletedExpired behavior
// - check CleanUpByAge behavior
validCfg := map[string]string{
"type": "sqlite",
"db_path": "./test.db",
//that's 15 days
"max_records_age": "360h",
"debug": "false",
"flush": "true",
}
ctx, err := NewDatabase(validCfg)
if err != nil || ctx == nil {
t.Fatalf("failed to create simple sqlite")
}
if err := ctx.DeleteAll(); err != nil {
t.Fatalf("failed to flush existing bans")
}
freshRecordsCount := 8
for i := 0; i < freshRecordsCount; i++ {
//this one expires in the future
OldSignal := genSignalOccurence(fmt.Sprintf("2.2.2.%d", i))
OldSignal.BanApplications[0].Until = time.Now().Add(1 * time.Hour)
if err = ctx.WriteSignal(OldSignal); err != nil {
t.Fatalf("Failed to insert old signal : %s", err)
}
}
oldRecordsCount := 150
for i := 0; i < oldRecordsCount; i++ {
OldSignal := genSignalOccurence(fmt.Sprintf("1.2.3.%d", i))
//let's make the event a month old
OldSignal.Start_at = time.Now().Add(-30 * 24 * time.Hour)
OldSignal.Stop_at = time.Now().Add(-30 * 24 * time.Hour)
//ban was like for an hour
OldSignal.BanApplications[0].Until = time.Now().Add(-30*24*time.Hour + 1*time.Hour)
//write the old signal
if err = ctx.WriteSignal(OldSignal); err != nil {
t.Fatalf("Failed to insert old signal : %s", err)
}
}
evtsCount := 0
ret := ctx.Db.Unscoped().Table("ban_applications").Count(&evtsCount)
if ret.Error != nil {
t.Fatalf("got err : %s", ret.Error)
}
if evtsCount != oldRecordsCount+freshRecordsCount {
t.Fatalf("got %d events", evtsCount)
}
//if we call DeleteExpired, it will soft deleted those events in the past
softDeleted, err := ctx.DeleteExpired()
if err != nil {
t.Fatalf("%s", err)
}
if softDeleted != oldRecordsCount {
t.Fatalf("%d deleted records", softDeleted)
}
//we should be left with *one* non-deleted record
evtsCount = 0
ret = ctx.Db.Table("ban_applications").Where("deleted_at is NULL").Count(&evtsCount)
if ret.Error != nil {
t.Fatalf("got err : %s", ret.Error)
}
if evtsCount != freshRecordsCount {
t.Fatalf("got %d events", evtsCount)
}
evtsCount = 0
ret = ctx.Db.Table("ban_applications").Where("deleted_at is not NULL").Count(&evtsCount)
if ret.Error != nil {
t.Fatalf("got err : %s", ret.Error)
}
if evtsCount != oldRecordsCount {
t.Fatalf("got %d events", evtsCount)
}
//Cleanup by age should hard delete old records
deleted, err := ctx.CleanUpRecordsByAge()
if err != nil {
t.Fatal()
}
if deleted != oldRecordsCount {
t.Fatalf("unexpected %d deleted events", deleted)
}
//and now we should have *one* record left !
evtsCount = 0
ret = ctx.Db.Unscoped().Table("ban_applications").Count(&evtsCount)
if ret.Error != nil {
t.Fatalf("got err : %s", ret.Error)
}
if evtsCount != freshRecordsCount {
t.Fatalf("got %d events", evtsCount)
}
}

View file

@ -15,9 +15,7 @@ func (c *Context) GetBansAt(at time.Time) ([]map[string]string, error) {
bas := []types.BanApplication{}
rets := make([]map[string]string, 0)
/*get non-expired records*/
//c.Db.LogMode(true)
//records := c.Db.Order("updated_at desc").Where(`strftime("%s", until) >= strftime("%s", ?) AND strftime("%s", created_at) < strftime("%s", ?)`, at, at).Group("ip_text").Find(&bas) /*.Count(&count)*/
records := c.Db.Order("updated_at desc").Where("until >= ? AND created_at < ?", at, at).Group("ip_text").Find(&bas) /*.Count(&count)*/
records := c.Db.Order("updated_at desc").Where("until >= ?", at).Group("ip_text").Find(&bas) /*.Count(&count)*/
if records.Error != nil {
return nil, records.Error
}
@ -26,8 +24,7 @@ func (c *Context) GetBansAt(at time.Time) ([]map[string]string, error) {
/*
fetch count of bans for this specific ip_text
*/
//ret := c.Db.Table("ban_applications").Order("updated_at desc").Where(`ip_text = ? AND strftime("%s", until) >= strftime("%s", ?) AND strftime("%s", created_at) < strftime("%s", ?) AND deleted_at is NULL`, ba.IpText, at, at).Count(&count)
ret := c.Db.Table("ban_applications").Order("updated_at desc").Where(`ip_text = ? AND until >= ? AND created_at < ? AND deleted_at is NULL`, ba.IpText, at, at).Count(&count)
ret := c.Db.Table("ban_applications").Order("updated_at desc").Where(`ip_text = ? AND until >= ? AND deleted_at is NULL`, ba.IpText, at).Count(&count)
if ret.Error != nil {
return nil, fmt.Errorf("failed to fetch records count for %s : %v", ba.IpText, ret.Error)
}

123
pkg/database/read_test.go Normal file
View file

@ -0,0 +1,123 @@
package database
import (
"testing"
"time"
)
func TestFetchBans(t *testing.T) {
//Plan:
// - flush db
// - write signal+ban for 1.2.3.4
// - get bans (as a connector) + check
// - write signal+ban for 1.2.3.5
// - get new bans (as a connector) + check
// - delete ban for 1.2.3.4
// - get deleted bans (as a connector) + check
validCfg := map[string]string{
"type": "sqlite",
"db_path": "./test.db",
"max_records_age": "72h",
"debug": "false",
"flush": "true",
}
ctx, err := NewDatabase(validCfg)
if err != nil || ctx == nil {
t.Fatalf("failed to create simple sqlite")
}
if err := ctx.DeleteAll(); err != nil {
t.Fatalf("failed to flush existing bans")
}
OldSignal := genSignalOccurence("1.2.3.4")
//write the old signal
if err = ctx.WriteSignal(OldSignal); err != nil {
t.Fatalf("Failed to insert old signal : %s", err)
}
//we startup, we should get one ban
firstFetch := time.Now()
bans, err := ctx.GetNewBan()
if err != nil {
t.Fatalf("%s", err)
}
if len(bans) != 1 {
t.Fatalf("expected one ban")
}
NewSignal := genSignalOccurence("1.2.3.5")
//write the old signal
if err = ctx.WriteSignal(NewSignal); err != nil {
t.Fatalf("Failed to insert old signal : %s", err)
}
//we startup, we should get one ban
bans, err = ctx.GetNewBanSince(firstFetch)
if err != nil {
t.Fatalf("%s", err)
}
firstFetch = time.Now()
if len(bans) != 1 {
t.Fatal()
}
if bans[0].MeasureSource != NewSignal.BanApplications[0].MeasureSource {
t.Fatal()
}
if bans[0].MeasureType != NewSignal.BanApplications[0].MeasureType {
t.Fatal()
}
if bans[0].StartIp != NewSignal.BanApplications[0].StartIp {
t.Fatal()
}
if bans[0].EndIp != NewSignal.BanApplications[0].EndIp {
t.Fatal()
}
if bans[0].Reason != NewSignal.BanApplications[0].Reason {
t.Fatal()
}
//Delete a ban
count, err := ctx.DeleteBan("1.2.3.4")
if err != nil {
t.Fatal()
}
if count != 1 {
t.Fatal()
}
//we shouldn't have any new bans
bans, err = ctx.GetNewBanSince(firstFetch)
if err != nil {
t.Fatal()
}
if len(bans) != 0 {
t.Fatal()
}
// //GetDeletedBanSince adds one second to the timestamp. why ? I'm not sure
// time.Sleep(1 * time.Second)
//but we should get a deleted ban
bans, err = ctx.GetDeletedBanSince(firstFetch.Add(-2 * time.Second))
if err != nil {
t.Fatalf("%s", err)
}
if len(bans) != 1 {
t.Fatalf("got %d", len(bans))
}
//OldSignal
if bans[0].MeasureSource != OldSignal.BanApplications[0].MeasureSource {
t.Fatal()
}
if bans[0].MeasureType != OldSignal.BanApplications[0].MeasureType {
t.Fatal()
}
if bans[0].StartIp != OldSignal.BanApplications[0].StartIp {
t.Fatal()
}
if bans[0].EndIp != OldSignal.BanApplications[0].EndIp {
t.Fatal()
}
if bans[0].Reason != OldSignal.BanApplications[0].Reason {
t.Fatal()
}
}

386
pkg/database/write_test.go Normal file
View file

@ -0,0 +1,386 @@
package database
import (
"database/sql"
"database/sql/driver"
"fmt"
"net"
"reflect"
"regexp"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/crowdsecurity/crowdsec/pkg/types"
"github.com/jinzhu/gorm"
"github.com/onsi/ginkgo"
"github.com/onsi/gomega"
)
type AnyTime struct{}
// Match satisfies sqlmock.Argument interface
func (a AnyTime) Match(v driver.Value) bool {
_, ok := v.(time.Time)
return ok
}
var _ = ginkgo.Describe("TestWrites", func() {
var ctx *Context
var mock sqlmock.Sqlmock
ginkgo.BeforeEach(func() {
var db *sql.DB
var err error
db, mock, err = sqlmock.New() // mock sql.DB
gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
gdb, err := gorm.Open("sqlite", db) // open gorm db
gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
ctx = &Context{Db: gdb}
//ctx.Db.LogMode(true)
})
ginkgo.AfterEach(func() {
err := mock.ExpectationsWereMet() // make sure all expectations were met
gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
})
ginkgo.Context("insert ban_applications", func() {
ginkgo.It("insert 1.2.3.4", func() {
const sqlSelectAll = `SELECT * FROM "ban_applications" WHERE "ban_applications"."deleted_at" IS NULL AND (("ban_applications"."ip_text" = ?)) ORDER BY "ban_applications"."id" ASC LIMIT 1`
insertBan := types.BanApplication{IpText: "1.2.3.4"}
mock.ExpectQuery(regexp.QuoteMeta(sqlSelectAll)).WithArgs("1.2.3.4").WillReturnRows(sqlmock.NewRows(nil))
mock.ExpectBegin()
const sqlInsertBanApplication = `INSERT INTO "ban_applications" ("created_at","updated_at","deleted_at","measure_source","measure_type","measure_extra","until","start_ip","end_ip","target_cn","target_as","target_as_name","ip_text","reason","scenario","signal_occurence_id") VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)`
InsertExpectedResult := sqlmock.NewResult(1, 1)
mock.ExpectExec(regexp.QuoteMeta(sqlInsertBanApplication)).WithArgs(
AnyTime{},
AnyTime{},
nil,
insertBan.MeasureSource,
insertBan.MeasureType,
insertBan.MeasureExtra,
AnyTime{},
insertBan.StartIp,
insertBan.EndIp,
insertBan.TargetCN,
insertBan.TargetAS,
insertBan.TargetASName,
insertBan.IpText,
insertBan.Reason,
insertBan.Scenario,
insertBan.SignalOccurenceID).WillReturnResult(InsertExpectedResult)
mock.ExpectCommit()
err := ctx.WriteBanApplication(insertBan)
gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
})
})
ginkgo.Context("insert signal_occurence", func() {
ginkgo.It("insert signal+ban for 1.2.3.4", func() {
insertBan := types.BanApplication{IpText: "1.2.3.4", SignalOccurenceID: 1}
insertSig := types.SignalOccurence{
MapKey: "ratata",
Scenario: "test_1",
BanApplications: []types.BanApplication{insertBan},
Source_ip: "1.2.3.4",
Source_range: "1.2.3.0/24",
Source_AutonomousSystemNumber: "1234",
}
//the part that try to delete pending existing bans
mock.ExpectBegin()
const sqlDeleteOldBan = `UPDATE "ban_applications" SET "deleted_at"=? WHERE "ban_applications"."deleted_at" IS NULL AND ((ip_text = ?))`
sqlDeleteOldBanResult := sqlmock.NewResult(1, 1)
mock.ExpectExec(regexp.QuoteMeta(sqlDeleteOldBan)).WithArgs(AnyTime{}, "1.2.3.4").WillReturnResult(sqlDeleteOldBanResult)
mock.ExpectCommit()
//insert the signal occurence
mock.ExpectBegin()
const sqlInsertNewEvent = `INSERT INTO "signal_occurences" ("created_at","updated_at","deleted_at","map_key","scenario","bucket_id","alert_message","events_count","start_at","stop_at","source_ip","source_range","source_autonomous_system_number","source_autonomous_system_organization","source_country","source_latitude","source_longitude","dest_ip","capacity","leak_speed","reprocess") VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)`
sqlInsertNewEventResult := sqlmock.NewResult(1, 1)
mock.ExpectExec(regexp.QuoteMeta(sqlInsertNewEvent)).WithArgs(
AnyTime{},
AnyTime{},
nil,
insertSig.MapKey,
insertSig.Scenario,
"",
"",
0,
AnyTime{},
AnyTime{},
insertSig.Source_ip,
insertSig.Source_range,
insertSig.Source_AutonomousSystemNumber,
"",
"",
0.0,
0.0,
"",
0,
0,
false,
).WillReturnResult(sqlInsertNewEventResult)
//insert the ban application
const sqlInsertBanApplication = `INSERT INTO "ban_applications" ("created_at","updated_at","deleted_at","measure_source","measure_type","measure_extra","until","start_ip","end_ip","target_cn","target_as","target_as_name","ip_text","reason","scenario","signal_occurence_id") VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)`
sqlInsertBanApplicationResults := sqlmock.NewResult(1, 1)
mock.ExpectExec(regexp.QuoteMeta(sqlInsertBanApplication)).WithArgs(
AnyTime{},
AnyTime{},
nil,
insertBan.MeasureSource,
insertBan.MeasureType,
insertBan.MeasureExtra,
AnyTime{},
insertBan.StartIp,
insertBan.EndIp,
insertBan.TargetCN,
insertBan.TargetAS,
insertBan.TargetASName,
insertBan.IpText,
insertBan.Reason,
insertBan.Scenario,
insertBan.SignalOccurenceID).WillReturnResult(sqlInsertBanApplicationResults)
mock.ExpectCommit()
err := ctx.WriteSignal(insertSig)
gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
})
})
ginkgo.Context("insert old signal_occurence + cleanup", func() {
ginkgo.It("insert signal+ban for 1.2.3.4", func() {
target_ip := net.ParseIP("1.2.3.4")
OldBan := types.BanApplication{
MeasureType: "ban",
MeasureSource: "local",
//expired one month ago
Until: time.Now().Add(-24 * 30 * time.Hour),
StartIp: types.IP2Int(target_ip),
EndIp: types.IP2Int(target_ip),
TargetCN: "FR",
TargetAS: 1234,
TargetASName: "Random AS",
IpText: target_ip.String(),
Reason: "A reason",
Scenario: "A scenario",
}
OldSignal := types.SignalOccurence{
MapKey: "lala",
Scenario: "old_overflow",
//two month ago : 24*60
Start_at: time.Now().Add(-24 * 60 * time.Hour),
Stop_at: time.Now().Add(-24 * 60 * time.Hour),
BanApplications: []types.BanApplication{OldBan},
}
//the part that try to delete pending existing bans
mock.ExpectBegin()
const sqlDeleteOldBan = `UPDATE "ban_applications" SET "deleted_at"=? WHERE "ban_applications"."deleted_at" IS NULL AND ((ip_text = ?))`
sqlDeleteOldBanResult := sqlmock.NewResult(1, 1)
mock.ExpectExec(regexp.QuoteMeta(sqlDeleteOldBan)).WithArgs(AnyTime{}, target_ip.String()).WillReturnResult(sqlDeleteOldBanResult)
mock.ExpectCommit()
//insert the signal occurence
mock.ExpectBegin()
const sqlInsertNewEvent = `INSERT INTO "signal_occurences" ("created_at","updated_at","deleted_at","map_key","scenario","bucket_id","alert_message","events_count","start_at","stop_at","source_ip","source_range","source_autonomous_system_number","source_autonomous_system_organization","source_country","source_latitude","source_longitude","dest_ip","capacity","leak_speed","reprocess") VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)`
sqlInsertNewEventResult := sqlmock.NewResult(1, 1)
mock.ExpectExec(regexp.QuoteMeta(sqlInsertNewEvent)).WithArgs(
AnyTime{},
AnyTime{},
nil,
OldSignal.MapKey,
OldSignal.Scenario,
"",
"",
0,
AnyTime{},
AnyTime{},
OldSignal.Source_ip,
OldSignal.Source_range,
OldSignal.Source_AutonomousSystemNumber,
"",
"",
0.0,
0.0,
"",
0,
0,
false,
).WillReturnResult(sqlInsertNewEventResult)
//insert the ban application
const sqlInsertBanApplication = `INSERT INTO "ban_applications" ("created_at","updated_at","deleted_at","measure_source","measure_type","measure_extra","until","start_ip","end_ip","target_cn","target_as","target_as_name","ip_text","reason","scenario","signal_occurence_id") VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)`
sqlInsertBanApplicationResults := sqlmock.NewResult(1, 1)
mock.ExpectExec(regexp.QuoteMeta(sqlInsertBanApplication)).WithArgs(
AnyTime{},
AnyTime{},
nil,
OldBan.MeasureSource,
OldBan.MeasureType,
OldBan.MeasureExtra,
AnyTime{},
OldBan.StartIp,
OldBan.EndIp,
OldBan.TargetCN,
OldBan.TargetAS,
OldBan.TargetASName,
OldBan.IpText,
OldBan.Reason,
OldBan.Scenario,
1).WillReturnResult(sqlInsertBanApplicationResults)
mock.ExpectCommit()
err := ctx.WriteSignal(OldSignal)
gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
})
})
})
func TestInsertSqlMock(t *testing.T) {
gomega.RegisterFailHandler(ginkgo.Fail)
ginkgo.RunSpecs(t, "TestWrites")
}
func TestInsertOldBans(t *testing.T) {
//Plan:
// - flush db
// - insert month old ban
// - use GetBansAt on current + past time and check results
// - @todo : we need to call the DeleteExpired and such
validCfg := map[string]string{
"type": "sqlite",
"db_path": "./test.db",
"max_records": "1000",
"max_records_age": "72h",
"debug": "false",
"flush": "true",
}
ctx, err := NewDatabase(validCfg)
if err != nil || ctx == nil {
t.Fatalf("failed to create simple sqlite")
}
if err := ctx.DeleteAll(); err != nil {
t.Fatalf("failed to flush existing bans")
}
target_ip := net.ParseIP("1.2.3.4")
OldBan := types.BanApplication{
MeasureType: "ban",
MeasureSource: "local",
//expired one month ago
Until: time.Now().Add(-24 * 30 * time.Hour),
StartIp: types.IP2Int(target_ip),
EndIp: types.IP2Int(target_ip),
TargetCN: "FR",
TargetAS: 1234,
TargetASName: "Random AS",
IpText: target_ip.String(),
Reason: "A reason",
Scenario: "A scenario",
}
OldSignal := types.SignalOccurence{
MapKey: "lala",
Scenario: "old_overflow",
//two month ago : 24*60
Start_at: time.Now().Add(-24 * 60 * time.Hour),
Stop_at: time.Now().Add(-24 * 60 * time.Hour),
BanApplications: []types.BanApplication{OldBan},
}
//write the old signal
err = ctx.WriteSignal(OldSignal)
if err != nil {
t.Fatalf("Failed to insert old signal : %s", err)
}
//fetch bans at current time
bans, err := ctx.GetBansAt(time.Now())
if err != nil {
t.Fatalf("failed to get bans : %s", err)
}
if len(bans) != 0 {
t.Fatalf("should not have bans, got %d bans", len(bans))
}
//get bans in the past
bans, err = ctx.GetBansAt(time.Now().Add(-24 * 31 * time.Hour))
if err != nil {
t.Fatalf("failed to get bans : %s", err)
}
if len(bans) != 1 {
t.Fatalf("should had 1 ban, got %d bans", len(bans))
}
if !reflect.DeepEqual(bans, []map[string]string{map[string]string{
"source": "local",
"until": "-720h0m0s",
"reason": "old_overflow",
"iptext": "1.2.3.4",
"cn": "",
"events_count": "0",
"action": "ban",
"as": " ",
"bancount": "0",
"scenario": "old_overflow",
}}) {
t.Fatalf("unexpected results")
}
}
func TestWriteBanApplicationOnly(t *testing.T) {
validCfg := map[string]string{
"type": "sqlite",
"db_path": "./test.db",
"debug": "false",
"flush": "true",
}
ctx, err := NewDatabase(validCfg)
if err != nil || ctx == nil {
t.Fatalf("failed to create simple sqlite")
}
if err := ctx.DeleteAll(); err != nil {
t.Fatalf("failed to flush existing bans")
}
freshRecordsCount := 12
for i := 0; i < freshRecordsCount; i++ {
//this one expires in the future
OldSignal := genSignalOccurence(fmt.Sprintf("2.2.2.%d", i))
OldSignal.BanApplications[0].Until = time.Now().Add(1 * time.Hour)
if err = ctx.WriteBanApplication(OldSignal.BanApplications[0]); err != nil {
t.Fatalf("Failed to insert old signal : %s", err)
}
}
bans, err := ctx.GetBansAt(time.Now())
if err != nil {
t.Fatalf("%s", err)
}
if len(bans) != freshRecordsCount {
t.Fatalf("expected %d, got %d", freshRecordsCount, len(bans))
}
}