Search: Make special character escaping compatible with SQLite #1994
This commit is contained in:
parent
e693fad8dc
commit
9e46a66f24
|
@ -10,6 +10,11 @@ import (
|
|||
"github.com/jinzhu/inflection"
|
||||
)
|
||||
|
||||
// Like escapes a string for use in a query.
|
||||
func Like(s string) string {
|
||||
return strings.Trim(sanitize.SqlString(s), " |&*%")
|
||||
}
|
||||
|
||||
// LikeAny returns a single where condition matching the search words.
|
||||
func LikeAny(col, s string, keywords, exact bool) (wheres []string) {
|
||||
if s == "" {
|
||||
|
@ -44,9 +49,9 @@ func LikeAny(col, s string, keywords, exact bool) (wheres []string) {
|
|||
|
||||
for _, w := range words {
|
||||
if wildcardThreshold > 0 && len(w) >= wildcardThreshold {
|
||||
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s%%'", col, SqlLike(w)))
|
||||
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s%%'", col, Like(w)))
|
||||
} else {
|
||||
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s'", col, SqlLike(w)))
|
||||
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s'", col, Like(w)))
|
||||
}
|
||||
|
||||
if !keywords || !txt.ContainsASCIILetters(w) {
|
||||
|
@ -56,7 +61,7 @@ func LikeAny(col, s string, keywords, exact bool) (wheres []string) {
|
|||
singular := inflection.Singular(w)
|
||||
|
||||
if singular != w {
|
||||
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s'", col, SqlLike(singular)))
|
||||
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s'", col, Like(singular)))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -103,9 +108,9 @@ func LikeAll(col, s string, keywords, exact bool) (wheres []string) {
|
|||
|
||||
for _, w := range words {
|
||||
if wildcardThreshold > 0 && len(w) >= wildcardThreshold {
|
||||
wheres = append(wheres, fmt.Sprintf("%s LIKE '%s%%'", col, SqlLike(w)))
|
||||
wheres = append(wheres, fmt.Sprintf("%s LIKE '%s%%'", col, Like(w)))
|
||||
} else {
|
||||
wheres = append(wheres, fmt.Sprintf("%s LIKE '%s'", col, SqlLike(w)))
|
||||
wheres = append(wheres, fmt.Sprintf("%s LIKE '%s'", col, Like(w)))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -140,9 +145,9 @@ func LikeAllNames(cols Cols, s string) (wheres []string) {
|
|||
|
||||
for _, c := range cols {
|
||||
if strings.Contains(w, txt.Space) {
|
||||
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s%%'", c, SqlLike(w)))
|
||||
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s%%'", c, Like(w)))
|
||||
} else {
|
||||
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%%%s%%'", c, SqlLike(w)))
|
||||
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%%%s%%'", c, Like(w)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -189,7 +194,7 @@ func AnySlug(col, search, sep string) (where string) {
|
|||
}
|
||||
|
||||
for _, w := range words {
|
||||
wheres = append(wheres, fmt.Sprintf("%s = '%s'", col, SqlLike(w)))
|
||||
wheres = append(wheres, fmt.Sprintf("%s = '%s'", col, Like(w)))
|
||||
}
|
||||
|
||||
return strings.Join(wheres, " OR ")
|
||||
|
|
|
@ -12,6 +12,24 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestLike(t *testing.T) {
|
||||
t.Run("Empty", func(t *testing.T) {
|
||||
assert.Equal(t, "", Like(""))
|
||||
})
|
||||
t.Run("Special", func(t *testing.T) {
|
||||
s := " ' \" \t \n %_''"
|
||||
exp := "'' \"\" %_''''"
|
||||
result := Like(s)
|
||||
t.Logf("String..: %s", s)
|
||||
t.Logf("Expected: %s", exp)
|
||||
t.Logf("Result..: %s", result)
|
||||
assert.Equal(t, exp, result)
|
||||
})
|
||||
t.Run("Alnum", func(t *testing.T) {
|
||||
assert.Equal(t, "123ABCabc", Like(" 123ABCabc%* "))
|
||||
})
|
||||
}
|
||||
|
||||
func TestLikeAny(t *testing.T) {
|
||||
t.Run("and_or_search", func(t *testing.T) {
|
||||
if w := LikeAny("k.keyword", "table spoon & usa | img json", true, false); len(w) != 2 {
|
||||
|
@ -119,11 +137,11 @@ func TestLikeAnyWord(t *testing.T) {
|
|||
}
|
||||
})
|
||||
t.Run("EscapeSql", func(t *testing.T) {
|
||||
if w := LikeAnyWord("k.keyword", "table% | 'spoon' & \"usa"); len(w) != 2 {
|
||||
if w := LikeAnyWord("k.keyword", "table% | 'spoon' & \"us'a"); len(w) != 2 {
|
||||
t.Fatalf("two where conditions expected: %#v", w)
|
||||
} else {
|
||||
assert.Equal(t, "k.keyword LIKE 'spoon%' OR k.keyword LIKE 'table%'", w[0])
|
||||
assert.Equal(t, "k.keyword LIKE '\\\"usa%'", w[1])
|
||||
assert.Equal(t, "k.keyword LIKE '\"\"us''a%'", w[1])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -99,8 +99,7 @@ func TestPhotosFilterAlbums(t *testing.T) {
|
|||
}
|
||||
assert.Equal(t, len(photos), 0)
|
||||
})
|
||||
// TODO should not throw error
|
||||
/*t.Run("albums middle '", func(t *testing.T) {
|
||||
t.Run("AlbumsSingleQuote", func(t *testing.T) {
|
||||
var f form.SearchPhotos
|
||||
|
||||
f.Albums = "Father's Day"
|
||||
|
@ -113,7 +112,7 @@ func TestPhotosFilterAlbums(t *testing.T) {
|
|||
}
|
||||
|
||||
assert.Greater(t, len(photos), 0)
|
||||
})*/
|
||||
})
|
||||
t.Run("albums end '", func(t *testing.T) {
|
||||
var f form.SearchPhotos
|
||||
|
||||
|
@ -190,8 +189,8 @@ func TestPhotosFilterAlbums(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// TODO: Needs review, variable number of results.
|
||||
|
||||
// TODO: Needs review, variable number of results.
|
||||
assert.GreaterOrEqual(t, len(photos), 0)
|
||||
})
|
||||
t.Run("albums end |", func(t *testing.T) {
|
||||
|
@ -340,8 +339,7 @@ func TestPhotosQueryAlbums(t *testing.T) {
|
|||
}
|
||||
assert.Equal(t, len(photos), 0)
|
||||
})
|
||||
//TODO should not throw error
|
||||
/*t.Run("albums middle '", func(t *testing.T) {
|
||||
t.Run("AlbumsQuerySingleQuote", func(t *testing.T) {
|
||||
var f form.SearchPhotos
|
||||
|
||||
f.Query = "albums:\"Father's Day\""
|
||||
|
@ -354,7 +352,7 @@ func TestPhotosQueryAlbums(t *testing.T) {
|
|||
}
|
||||
|
||||
assert.Greater(t, len(photos), 0)
|
||||
})*/
|
||||
})
|
||||
t.Run("albums end '", func(t *testing.T) {
|
||||
var f form.SearchPhotos
|
||||
|
||||
|
@ -431,8 +429,8 @@ func TestPhotosQueryAlbums(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// TODO: Needs review, variable number of results.
|
||||
|
||||
// TODO: Needs review, variable number of results.
|
||||
assert.GreaterOrEqual(t, len(photos), 0)
|
||||
})
|
||||
t.Run("albums end |", func(t *testing.T) {
|
||||
|
|
|
@ -1,12 +0,0 @@
|
|||
package search
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/photoprism/photoprism/pkg/sanitize"
|
||||
)
|
||||
|
||||
// SqlLike escapes a string for use in an SQL query.
|
||||
func SqlLike(s string) string {
|
||||
return strings.Trim(sanitize.SqlString(s), " |&*%")
|
||||
}
|
|
@ -1,25 +0,0 @@
|
|||
package search
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSqlLike(t *testing.T) {
|
||||
t.Run("Empty", func(t *testing.T) {
|
||||
assert.Equal(t, "", SqlLike(""))
|
||||
})
|
||||
t.Run("Special", func(t *testing.T) {
|
||||
s := "' \" \t \n %_''"
|
||||
exp := "\\' \\\" %\\_\\'\\'"
|
||||
result := SqlLike(s)
|
||||
t.Logf("String..: %s", s)
|
||||
t.Logf("Expected: %s", exp)
|
||||
t.Logf("Result..: %s", result)
|
||||
assert.Equal(t, exp, result)
|
||||
})
|
||||
t.Run("Alnum", func(t *testing.T) {
|
||||
assert.Equal(t, "123ABCabc", SqlLike(" 123ABCabc%* "))
|
||||
})
|
||||
}
|
|
@ -1,41 +1,53 @@
|
|||
package sanitize
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
)
|
||||
// SqlSpecial checks if the byte must be escaped/omitted in SQL.
|
||||
func SqlSpecial(b byte) (special bool, omit bool) {
|
||||
if b < 32 {
|
||||
return true, true
|
||||
}
|
||||
|
||||
// sqlSpecialBytes contains special bytes to escape in SQL search queries.
|
||||
// see https://mariadb.com/kb/en/string-literals/
|
||||
var sqlSpecialBytes = []byte{34, 39, 92, 95} // ", ', \, _
|
||||
switch b {
|
||||
case '"', '\'', '\\':
|
||||
return true, false
|
||||
default:
|
||||
return false, false
|
||||
}
|
||||
}
|
||||
|
||||
// SqlString escapes a string for use in an SQL query.
|
||||
func SqlString(s string) string {
|
||||
var i int
|
||||
for i = 0; i < len(s); i++ {
|
||||
if bytes.Contains(sqlSpecialBytes, []byte{s[i]}) {
|
||||
if found, _ := SqlSpecial(s[i]); found {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// No special characters found, return original string.
|
||||
// Return if no special characters were found.
|
||||
if i >= len(s) {
|
||||
return s
|
||||
}
|
||||
|
||||
b := make([]byte, 2*len(s)-i)
|
||||
|
||||
copy(b, s[:i])
|
||||
|
||||
j := i
|
||||
|
||||
for ; i < len(s); i++ {
|
||||
if s[i] < 31 {
|
||||
// Ignore control chars.
|
||||
if special, omit := SqlSpecial(s[i]); omit {
|
||||
// Omit control characters.
|
||||
continue
|
||||
}
|
||||
if bytes.Contains(sqlSpecialBytes, []byte{s[i]}) {
|
||||
b[j] = '\\'
|
||||
} else if special {
|
||||
// Escape other special characters.
|
||||
// see https://mariadb.com/kb/en/string-literals/
|
||||
b[j] = s[i]
|
||||
j++
|
||||
}
|
||||
|
||||
b[j] = s[i]
|
||||
j++
|
||||
}
|
||||
|
||||
return string(b[:j])
|
||||
}
|
||||
|
|
|
@ -6,13 +6,72 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSqlSpecial(t *testing.T) {
|
||||
t.Run("Special", func(t *testing.T) {
|
||||
if s, o := SqlSpecial(1); !s {
|
||||
t.Error("char is special")
|
||||
} else if !o {
|
||||
t.Error("\" must be omitted")
|
||||
}
|
||||
|
||||
if s, o := SqlSpecial(31); !s {
|
||||
t.Error("char is special")
|
||||
} else if !o {
|
||||
t.Error("\" must be omitted")
|
||||
}
|
||||
|
||||
if s, o := SqlSpecial('\\'); !s {
|
||||
t.Error("\\ is special")
|
||||
} else if o {
|
||||
t.Error("\\ must not be omitted")
|
||||
}
|
||||
|
||||
if s, o := SqlSpecial('\''); !s {
|
||||
t.Error("' is special")
|
||||
} else if o {
|
||||
t.Error("' must not be omitted")
|
||||
}
|
||||
|
||||
if s, o := SqlSpecial('"'); !s {
|
||||
t.Error("\" is special")
|
||||
} else if o {
|
||||
t.Error("\" must not be omitted")
|
||||
}
|
||||
})
|
||||
t.Run("NotSpecial", func(t *testing.T) {
|
||||
if s, o := SqlSpecial(32); s {
|
||||
t.Error("space is not special")
|
||||
} else if o {
|
||||
t.Error("space must not be omitted")
|
||||
}
|
||||
|
||||
if s, o := SqlSpecial('A'); s {
|
||||
t.Error("A is not special")
|
||||
} else if o {
|
||||
t.Error("A must not be omitted")
|
||||
}
|
||||
|
||||
if s, o := SqlSpecial('a'); s {
|
||||
t.Error("a is not special")
|
||||
} else if o {
|
||||
t.Error("a must not be omitted")
|
||||
}
|
||||
|
||||
if s, o := SqlSpecial('_'); s {
|
||||
t.Error("_ is not special")
|
||||
} else if o {
|
||||
t.Error("_ must not be omitted")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlString(t *testing.T) {
|
||||
t.Run("Empty", func(t *testing.T) {
|
||||
assert.Equal(t, "", SqlString(""))
|
||||
})
|
||||
t.Run("Special", func(t *testing.T) {
|
||||
s := "' \" \t \n %_''"
|
||||
exp := "\\' \\\" %\\_\\'\\'"
|
||||
exp := "'' \"\" %_''''"
|
||||
result := SqlString(s)
|
||||
t.Logf("String..: %s", s)
|
||||
t.Logf("Expected: %s", exp)
|
||||
|
|
Loading…
Reference in a new issue