Search: Make special character escaping compatible with SQLite #1994

This commit is contained in:
Michael Mayer 2022-03-28 17:36:59 +02:00
parent e693fad8dc
commit 9e46a66f24
7 changed files with 124 additions and 69 deletions

View file

@ -10,6 +10,11 @@ import (
"github.com/jinzhu/inflection" "github.com/jinzhu/inflection"
) )
// Like escapes a string for use in a query.
func Like(s string) string {
return strings.Trim(sanitize.SqlString(s), " |&*%")
}
// LikeAny returns a single where condition matching the search words. // LikeAny returns a single where condition matching the search words.
func LikeAny(col, s string, keywords, exact bool) (wheres []string) { func LikeAny(col, s string, keywords, exact bool) (wheres []string) {
if s == "" { if s == "" {
@ -44,9 +49,9 @@ func LikeAny(col, s string, keywords, exact bool) (wheres []string) {
for _, w := range words { for _, w := range words {
if wildcardThreshold > 0 && len(w) >= wildcardThreshold { if wildcardThreshold > 0 && len(w) >= wildcardThreshold {
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s%%'", col, SqlLike(w))) orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s%%'", col, Like(w)))
} else { } else {
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s'", col, SqlLike(w))) orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s'", col, Like(w)))
} }
if !keywords || !txt.ContainsASCIILetters(w) { if !keywords || !txt.ContainsASCIILetters(w) {
@ -56,7 +61,7 @@ func LikeAny(col, s string, keywords, exact bool) (wheres []string) {
singular := inflection.Singular(w) singular := inflection.Singular(w)
if singular != w { if singular != w {
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s'", col, SqlLike(singular))) orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s'", col, Like(singular)))
} }
} }
@ -103,9 +108,9 @@ func LikeAll(col, s string, keywords, exact bool) (wheres []string) {
for _, w := range words { for _, w := range words {
if wildcardThreshold > 0 && len(w) >= wildcardThreshold { if wildcardThreshold > 0 && len(w) >= wildcardThreshold {
wheres = append(wheres, fmt.Sprintf("%s LIKE '%s%%'", col, SqlLike(w))) wheres = append(wheres, fmt.Sprintf("%s LIKE '%s%%'", col, Like(w)))
} else { } else {
wheres = append(wheres, fmt.Sprintf("%s LIKE '%s'", col, SqlLike(w))) wheres = append(wheres, fmt.Sprintf("%s LIKE '%s'", col, Like(w)))
} }
} }
@ -140,9 +145,9 @@ func LikeAllNames(cols Cols, s string) (wheres []string) {
for _, c := range cols { for _, c := range cols {
if strings.Contains(w, txt.Space) { if strings.Contains(w, txt.Space) {
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s%%'", c, SqlLike(w))) orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s%%'", c, Like(w)))
} else { } else {
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%%%s%%'", c, SqlLike(w))) orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%%%s%%'", c, Like(w)))
} }
} }
} }
@ -189,7 +194,7 @@ func AnySlug(col, search, sep string) (where string) {
} }
for _, w := range words { for _, w := range words {
wheres = append(wheres, fmt.Sprintf("%s = '%s'", col, SqlLike(w))) wheres = append(wheres, fmt.Sprintf("%s = '%s'", col, Like(w)))
} }
return strings.Join(wheres, " OR ") return strings.Join(wheres, " OR ")

View file

@ -12,6 +12,24 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestLike(t *testing.T) {
t.Run("Empty", func(t *testing.T) {
assert.Equal(t, "", Like(""))
})
t.Run("Special", func(t *testing.T) {
s := " ' \" \t \n %_''"
exp := "'' \"\" %_''''"
result := Like(s)
t.Logf("String..: %s", s)
t.Logf("Expected: %s", exp)
t.Logf("Result..: %s", result)
assert.Equal(t, exp, result)
})
t.Run("Alnum", func(t *testing.T) {
assert.Equal(t, "123ABCabc", Like(" 123ABCabc%* "))
})
}
func TestLikeAny(t *testing.T) { func TestLikeAny(t *testing.T) {
t.Run("and_or_search", func(t *testing.T) { t.Run("and_or_search", func(t *testing.T) {
if w := LikeAny("k.keyword", "table spoon & usa | img json", true, false); len(w) != 2 { if w := LikeAny("k.keyword", "table spoon & usa | img json", true, false); len(w) != 2 {
@ -119,11 +137,11 @@ func TestLikeAnyWord(t *testing.T) {
} }
}) })
t.Run("EscapeSql", func(t *testing.T) { t.Run("EscapeSql", func(t *testing.T) {
if w := LikeAnyWord("k.keyword", "table% | 'spoon' & \"usa"); len(w) != 2 { if w := LikeAnyWord("k.keyword", "table% | 'spoon' & \"us'a"); len(w) != 2 {
t.Fatalf("two where conditions expected: %#v", w) t.Fatalf("two where conditions expected: %#v", w)
} else { } else {
assert.Equal(t, "k.keyword LIKE 'spoon%' OR k.keyword LIKE 'table%'", w[0]) assert.Equal(t, "k.keyword LIKE 'spoon%' OR k.keyword LIKE 'table%'", w[0])
assert.Equal(t, "k.keyword LIKE '\\\"usa%'", w[1]) assert.Equal(t, "k.keyword LIKE '\"\"us''a%'", w[1])
} }
}) })
} }

View file

@ -99,8 +99,7 @@ func TestPhotosFilterAlbums(t *testing.T) {
} }
assert.Equal(t, len(photos), 0) assert.Equal(t, len(photos), 0)
}) })
// TODO should not throw error t.Run("AlbumsSingleQuote", func(t *testing.T) {
/*t.Run("albums middle '", func(t *testing.T) {
var f form.SearchPhotos var f form.SearchPhotos
f.Albums = "Father's Day" f.Albums = "Father's Day"
@ -113,7 +112,7 @@ func TestPhotosFilterAlbums(t *testing.T) {
} }
assert.Greater(t, len(photos), 0) assert.Greater(t, len(photos), 0)
})*/ })
t.Run("albums end '", func(t *testing.T) { t.Run("albums end '", func(t *testing.T) {
var f form.SearchPhotos var f form.SearchPhotos
@ -190,8 +189,8 @@ func TestPhotosFilterAlbums(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// TODO: Needs review, variable number of results.
// TODO: Needs review, variable number of results.
assert.GreaterOrEqual(t, len(photos), 0) assert.GreaterOrEqual(t, len(photos), 0)
}) })
t.Run("albums end |", func(t *testing.T) { t.Run("albums end |", func(t *testing.T) {
@ -340,8 +339,7 @@ func TestPhotosQueryAlbums(t *testing.T) {
} }
assert.Equal(t, len(photos), 0) assert.Equal(t, len(photos), 0)
}) })
//TODO should not throw error t.Run("AlbumsQuerySingleQuote", func(t *testing.T) {
/*t.Run("albums middle '", func(t *testing.T) {
var f form.SearchPhotos var f form.SearchPhotos
f.Query = "albums:\"Father's Day\"" f.Query = "albums:\"Father's Day\""
@ -354,7 +352,7 @@ func TestPhotosQueryAlbums(t *testing.T) {
} }
assert.Greater(t, len(photos), 0) assert.Greater(t, len(photos), 0)
})*/ })
t.Run("albums end '", func(t *testing.T) { t.Run("albums end '", func(t *testing.T) {
var f form.SearchPhotos var f form.SearchPhotos
@ -431,8 +429,8 @@ func TestPhotosQueryAlbums(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// TODO: Needs review, variable number of results.
// TODO: Needs review, variable number of results.
assert.GreaterOrEqual(t, len(photos), 0) assert.GreaterOrEqual(t, len(photos), 0)
}) })
t.Run("albums end |", func(t *testing.T) { t.Run("albums end |", func(t *testing.T) {

View file

@ -1,12 +0,0 @@
package search
import (
"strings"
"github.com/photoprism/photoprism/pkg/sanitize"
)
// SqlLike escapes a string for use in an SQL query.
func SqlLike(s string) string {
return strings.Trim(sanitize.SqlString(s), " |&*%")
}

View file

@ -1,25 +0,0 @@
package search
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestSqlLike(t *testing.T) {
t.Run("Empty", func(t *testing.T) {
assert.Equal(t, "", SqlLike(""))
})
t.Run("Special", func(t *testing.T) {
s := "' \" \t \n %_''"
exp := "\\' \\\" %\\_\\'\\'"
result := SqlLike(s)
t.Logf("String..: %s", s)
t.Logf("Expected: %s", exp)
t.Logf("Result..: %s", result)
assert.Equal(t, exp, result)
})
t.Run("Alnum", func(t *testing.T) {
assert.Equal(t, "123ABCabc", SqlLike(" 123ABCabc%* "))
})
}

View file

@ -1,41 +1,53 @@
package sanitize package sanitize
import ( // SqlSpecial checks if the byte must be escaped/omitted in SQL.
"bytes" func SqlSpecial(b byte) (special bool, omit bool) {
) if b < 32 {
return true, true
}
// sqlSpecialBytes contains special bytes to escape in SQL search queries. switch b {
// see https://mariadb.com/kb/en/string-literals/ case '"', '\'', '\\':
var sqlSpecialBytes = []byte{34, 39, 92, 95} // ", ', \, _ return true, false
default:
return false, false
}
}
// SqlString escapes a string for use in an SQL query. // SqlString escapes a string for use in an SQL query.
func SqlString(s string) string { func SqlString(s string) string {
var i int var i int
for i = 0; i < len(s); i++ { for i = 0; i < len(s); i++ {
if bytes.Contains(sqlSpecialBytes, []byte{s[i]}) { if found, _ := SqlSpecial(s[i]); found {
break break
} }
} }
// No special characters found, return original string. // Return if no special characters were found.
if i >= len(s) { if i >= len(s) {
return s return s
} }
b := make([]byte, 2*len(s)-i) b := make([]byte, 2*len(s)-i)
copy(b, s[:i]) copy(b, s[:i])
j := i j := i
for ; i < len(s); i++ { for ; i < len(s); i++ {
if s[i] < 31 { if special, omit := SqlSpecial(s[i]); omit {
// Ignore control chars. // Omit control characters.
continue continue
} } else if special {
if bytes.Contains(sqlSpecialBytes, []byte{s[i]}) { // Escape other special characters.
b[j] = '\\' // see https://mariadb.com/kb/en/string-literals/
j++
}
b[j] = s[i] b[j] = s[i]
j++ j++
} }
b[j] = s[i]
j++
}
return string(b[:j]) return string(b[:j])
} }

View file

@ -6,13 +6,72 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestSqlSpecial(t *testing.T) {
t.Run("Special", func(t *testing.T) {
if s, o := SqlSpecial(1); !s {
t.Error("char is special")
} else if !o {
t.Error("\" must be omitted")
}
if s, o := SqlSpecial(31); !s {
t.Error("char is special")
} else if !o {
t.Error("\" must be omitted")
}
if s, o := SqlSpecial('\\'); !s {
t.Error("\\ is special")
} else if o {
t.Error("\\ must not be omitted")
}
if s, o := SqlSpecial('\''); !s {
t.Error("' is special")
} else if o {
t.Error("' must not be omitted")
}
if s, o := SqlSpecial('"'); !s {
t.Error("\" is special")
} else if o {
t.Error("\" must not be omitted")
}
})
t.Run("NotSpecial", func(t *testing.T) {
if s, o := SqlSpecial(32); s {
t.Error("space is not special")
} else if o {
t.Error("space must not be omitted")
}
if s, o := SqlSpecial('A'); s {
t.Error("A is not special")
} else if o {
t.Error("A must not be omitted")
}
if s, o := SqlSpecial('a'); s {
t.Error("a is not special")
} else if o {
t.Error("a must not be omitted")
}
if s, o := SqlSpecial('_'); s {
t.Error("_ is not special")
} else if o {
t.Error("_ must not be omitted")
}
})
}
func TestSqlString(t *testing.T) { func TestSqlString(t *testing.T) {
t.Run("Empty", func(t *testing.T) { t.Run("Empty", func(t *testing.T) {
assert.Equal(t, "", SqlString("")) assert.Equal(t, "", SqlString(""))
}) })
t.Run("Special", func(t *testing.T) { t.Run("Special", func(t *testing.T) {
s := "' \" \t \n %_''" s := "' \" \t \n %_''"
exp := "\\' \\\" %\\_\\'\\'" exp := "'' \"\" %_''''"
result := SqlString(s) result := SqlString(s)
t.Logf("String..: %s", s) t.Logf("String..: %s", s)
t.Logf("Expected: %s", exp) t.Logf("Expected: %s", exp)