Search: Make special character escaping compatible with SQLite #1994
This commit is contained in:
parent
e693fad8dc
commit
9e46a66f24
|
@ -10,6 +10,11 @@ import (
|
||||||
"github.com/jinzhu/inflection"
|
"github.com/jinzhu/inflection"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Like escapes a string for use in a query.
|
||||||
|
func Like(s string) string {
|
||||||
|
return strings.Trim(sanitize.SqlString(s), " |&*%")
|
||||||
|
}
|
||||||
|
|
||||||
// LikeAny returns a single where condition matching the search words.
|
// LikeAny returns a single where condition matching the search words.
|
||||||
func LikeAny(col, s string, keywords, exact bool) (wheres []string) {
|
func LikeAny(col, s string, keywords, exact bool) (wheres []string) {
|
||||||
if s == "" {
|
if s == "" {
|
||||||
|
@ -44,9 +49,9 @@ func LikeAny(col, s string, keywords, exact bool) (wheres []string) {
|
||||||
|
|
||||||
for _, w := range words {
|
for _, w := range words {
|
||||||
if wildcardThreshold > 0 && len(w) >= wildcardThreshold {
|
if wildcardThreshold > 0 && len(w) >= wildcardThreshold {
|
||||||
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s%%'", col, SqlLike(w)))
|
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s%%'", col, Like(w)))
|
||||||
} else {
|
} else {
|
||||||
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s'", col, SqlLike(w)))
|
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s'", col, Like(w)))
|
||||||
}
|
}
|
||||||
|
|
||||||
if !keywords || !txt.ContainsASCIILetters(w) {
|
if !keywords || !txt.ContainsASCIILetters(w) {
|
||||||
|
@ -56,7 +61,7 @@ func LikeAny(col, s string, keywords, exact bool) (wheres []string) {
|
||||||
singular := inflection.Singular(w)
|
singular := inflection.Singular(w)
|
||||||
|
|
||||||
if singular != w {
|
if singular != w {
|
||||||
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s'", col, SqlLike(singular)))
|
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s'", col, Like(singular)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,9 +108,9 @@ func LikeAll(col, s string, keywords, exact bool) (wheres []string) {
|
||||||
|
|
||||||
for _, w := range words {
|
for _, w := range words {
|
||||||
if wildcardThreshold > 0 && len(w) >= wildcardThreshold {
|
if wildcardThreshold > 0 && len(w) >= wildcardThreshold {
|
||||||
wheres = append(wheres, fmt.Sprintf("%s LIKE '%s%%'", col, SqlLike(w)))
|
wheres = append(wheres, fmt.Sprintf("%s LIKE '%s%%'", col, Like(w)))
|
||||||
} else {
|
} else {
|
||||||
wheres = append(wheres, fmt.Sprintf("%s LIKE '%s'", col, SqlLike(w)))
|
wheres = append(wheres, fmt.Sprintf("%s LIKE '%s'", col, Like(w)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -140,9 +145,9 @@ func LikeAllNames(cols Cols, s string) (wheres []string) {
|
||||||
|
|
||||||
for _, c := range cols {
|
for _, c := range cols {
|
||||||
if strings.Contains(w, txt.Space) {
|
if strings.Contains(w, txt.Space) {
|
||||||
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s%%'", c, SqlLike(w)))
|
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s%%'", c, Like(w)))
|
||||||
} else {
|
} else {
|
||||||
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%%%s%%'", c, SqlLike(w)))
|
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%%%s%%'", c, Like(w)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -189,7 +194,7 @@ func AnySlug(col, search, sep string) (where string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, w := range words {
|
for _, w := range words {
|
||||||
wheres = append(wheres, fmt.Sprintf("%s = '%s'", col, SqlLike(w)))
|
wheres = append(wheres, fmt.Sprintf("%s = '%s'", col, Like(w)))
|
||||||
}
|
}
|
||||||
|
|
||||||
return strings.Join(wheres, " OR ")
|
return strings.Join(wheres, " OR ")
|
||||||
|
|
|
@ -12,6 +12,24 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestLike(t *testing.T) {
|
||||||
|
t.Run("Empty", func(t *testing.T) {
|
||||||
|
assert.Equal(t, "", Like(""))
|
||||||
|
})
|
||||||
|
t.Run("Special", func(t *testing.T) {
|
||||||
|
s := " ' \" \t \n %_''"
|
||||||
|
exp := "'' \"\" %_''''"
|
||||||
|
result := Like(s)
|
||||||
|
t.Logf("String..: %s", s)
|
||||||
|
t.Logf("Expected: %s", exp)
|
||||||
|
t.Logf("Result..: %s", result)
|
||||||
|
assert.Equal(t, exp, result)
|
||||||
|
})
|
||||||
|
t.Run("Alnum", func(t *testing.T) {
|
||||||
|
assert.Equal(t, "123ABCabc", Like(" 123ABCabc%* "))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestLikeAny(t *testing.T) {
|
func TestLikeAny(t *testing.T) {
|
||||||
t.Run("and_or_search", func(t *testing.T) {
|
t.Run("and_or_search", func(t *testing.T) {
|
||||||
if w := LikeAny("k.keyword", "table spoon & usa | img json", true, false); len(w) != 2 {
|
if w := LikeAny("k.keyword", "table spoon & usa | img json", true, false); len(w) != 2 {
|
||||||
|
@ -119,11 +137,11 @@ func TestLikeAnyWord(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("EscapeSql", func(t *testing.T) {
|
t.Run("EscapeSql", func(t *testing.T) {
|
||||||
if w := LikeAnyWord("k.keyword", "table% | 'spoon' & \"usa"); len(w) != 2 {
|
if w := LikeAnyWord("k.keyword", "table% | 'spoon' & \"us'a"); len(w) != 2 {
|
||||||
t.Fatalf("two where conditions expected: %#v", w)
|
t.Fatalf("two where conditions expected: %#v", w)
|
||||||
} else {
|
} else {
|
||||||
assert.Equal(t, "k.keyword LIKE 'spoon%' OR k.keyword LIKE 'table%'", w[0])
|
assert.Equal(t, "k.keyword LIKE 'spoon%' OR k.keyword LIKE 'table%'", w[0])
|
||||||
assert.Equal(t, "k.keyword LIKE '\\\"usa%'", w[1])
|
assert.Equal(t, "k.keyword LIKE '\"\"us''a%'", w[1])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -99,8 +99,7 @@ func TestPhotosFilterAlbums(t *testing.T) {
|
||||||
}
|
}
|
||||||
assert.Equal(t, len(photos), 0)
|
assert.Equal(t, len(photos), 0)
|
||||||
})
|
})
|
||||||
// TODO should not throw error
|
t.Run("AlbumsSingleQuote", func(t *testing.T) {
|
||||||
/*t.Run("albums middle '", func(t *testing.T) {
|
|
||||||
var f form.SearchPhotos
|
var f form.SearchPhotos
|
||||||
|
|
||||||
f.Albums = "Father's Day"
|
f.Albums = "Father's Day"
|
||||||
|
@ -113,7 +112,7 @@ func TestPhotosFilterAlbums(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Greater(t, len(photos), 0)
|
assert.Greater(t, len(photos), 0)
|
||||||
})*/
|
})
|
||||||
t.Run("albums end '", func(t *testing.T) {
|
t.Run("albums end '", func(t *testing.T) {
|
||||||
var f form.SearchPhotos
|
var f form.SearchPhotos
|
||||||
|
|
||||||
|
@ -190,8 +189,8 @@ func TestPhotosFilterAlbums(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
// TODO: Needs review, variable number of results.
|
|
||||||
|
|
||||||
|
// TODO: Needs review, variable number of results.
|
||||||
assert.GreaterOrEqual(t, len(photos), 0)
|
assert.GreaterOrEqual(t, len(photos), 0)
|
||||||
})
|
})
|
||||||
t.Run("albums end |", func(t *testing.T) {
|
t.Run("albums end |", func(t *testing.T) {
|
||||||
|
@ -340,8 +339,7 @@ func TestPhotosQueryAlbums(t *testing.T) {
|
||||||
}
|
}
|
||||||
assert.Equal(t, len(photos), 0)
|
assert.Equal(t, len(photos), 0)
|
||||||
})
|
})
|
||||||
//TODO should not throw error
|
t.Run("AlbumsQuerySingleQuote", func(t *testing.T) {
|
||||||
/*t.Run("albums middle '", func(t *testing.T) {
|
|
||||||
var f form.SearchPhotos
|
var f form.SearchPhotos
|
||||||
|
|
||||||
f.Query = "albums:\"Father's Day\""
|
f.Query = "albums:\"Father's Day\""
|
||||||
|
@ -354,7 +352,7 @@ func TestPhotosQueryAlbums(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Greater(t, len(photos), 0)
|
assert.Greater(t, len(photos), 0)
|
||||||
})*/
|
})
|
||||||
t.Run("albums end '", func(t *testing.T) {
|
t.Run("albums end '", func(t *testing.T) {
|
||||||
var f form.SearchPhotos
|
var f form.SearchPhotos
|
||||||
|
|
||||||
|
@ -431,8 +429,8 @@ func TestPhotosQueryAlbums(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
// TODO: Needs review, variable number of results.
|
|
||||||
|
|
||||||
|
// TODO: Needs review, variable number of results.
|
||||||
assert.GreaterOrEqual(t, len(photos), 0)
|
assert.GreaterOrEqual(t, len(photos), 0)
|
||||||
})
|
})
|
||||||
t.Run("albums end |", func(t *testing.T) {
|
t.Run("albums end |", func(t *testing.T) {
|
||||||
|
|
|
@ -1,12 +0,0 @@
|
||||||
package search
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/photoprism/photoprism/pkg/sanitize"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SqlLike escapes a string for use in an SQL query.
|
|
||||||
func SqlLike(s string) string {
|
|
||||||
return strings.Trim(sanitize.SqlString(s), " |&*%")
|
|
||||||
}
|
|
|
@ -1,25 +0,0 @@
|
||||||
package search
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestSqlLike(t *testing.T) {
|
|
||||||
t.Run("Empty", func(t *testing.T) {
|
|
||||||
assert.Equal(t, "", SqlLike(""))
|
|
||||||
})
|
|
||||||
t.Run("Special", func(t *testing.T) {
|
|
||||||
s := "' \" \t \n %_''"
|
|
||||||
exp := "\\' \\\" %\\_\\'\\'"
|
|
||||||
result := SqlLike(s)
|
|
||||||
t.Logf("String..: %s", s)
|
|
||||||
t.Logf("Expected: %s", exp)
|
|
||||||
t.Logf("Result..: %s", result)
|
|
||||||
assert.Equal(t, exp, result)
|
|
||||||
})
|
|
||||||
t.Run("Alnum", func(t *testing.T) {
|
|
||||||
assert.Equal(t, "123ABCabc", SqlLike(" 123ABCabc%* "))
|
|
||||||
})
|
|
||||||
}
|
|
|
@ -1,41 +1,53 @@
|
||||||
package sanitize
|
package sanitize
|
||||||
|
|
||||||
import (
|
// SqlSpecial checks if the byte must be escaped/omitted in SQL.
|
||||||
"bytes"
|
func SqlSpecial(b byte) (special bool, omit bool) {
|
||||||
)
|
if b < 32 {
|
||||||
|
return true, true
|
||||||
|
}
|
||||||
|
|
||||||
// sqlSpecialBytes contains special bytes to escape in SQL search queries.
|
switch b {
|
||||||
// see https://mariadb.com/kb/en/string-literals/
|
case '"', '\'', '\\':
|
||||||
var sqlSpecialBytes = []byte{34, 39, 92, 95} // ", ', \, _
|
return true, false
|
||||||
|
default:
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// SqlString escapes a string for use in an SQL query.
|
// SqlString escapes a string for use in an SQL query.
|
||||||
func SqlString(s string) string {
|
func SqlString(s string) string {
|
||||||
var i int
|
var i int
|
||||||
for i = 0; i < len(s); i++ {
|
for i = 0; i < len(s); i++ {
|
||||||
if bytes.Contains(sqlSpecialBytes, []byte{s[i]}) {
|
if found, _ := SqlSpecial(s[i]); found {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// No special characters found, return original string.
|
// Return if no special characters were found.
|
||||||
if i >= len(s) {
|
if i >= len(s) {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
b := make([]byte, 2*len(s)-i)
|
b := make([]byte, 2*len(s)-i)
|
||||||
|
|
||||||
copy(b, s[:i])
|
copy(b, s[:i])
|
||||||
|
|
||||||
j := i
|
j := i
|
||||||
|
|
||||||
for ; i < len(s); i++ {
|
for ; i < len(s); i++ {
|
||||||
if s[i] < 31 {
|
if special, omit := SqlSpecial(s[i]); omit {
|
||||||
// Ignore control chars.
|
// Omit control characters.
|
||||||
continue
|
continue
|
||||||
}
|
} else if special {
|
||||||
if bytes.Contains(sqlSpecialBytes, []byte{s[i]}) {
|
// Escape other special characters.
|
||||||
b[j] = '\\'
|
// see https://mariadb.com/kb/en/string-literals/
|
||||||
j++
|
|
||||||
}
|
|
||||||
b[j] = s[i]
|
b[j] = s[i]
|
||||||
j++
|
j++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
b[j] = s[i]
|
||||||
|
j++
|
||||||
|
}
|
||||||
|
|
||||||
return string(b[:j])
|
return string(b[:j])
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,13 +6,72 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestSqlSpecial(t *testing.T) {
|
||||||
|
t.Run("Special", func(t *testing.T) {
|
||||||
|
if s, o := SqlSpecial(1); !s {
|
||||||
|
t.Error("char is special")
|
||||||
|
} else if !o {
|
||||||
|
t.Error("\" must be omitted")
|
||||||
|
}
|
||||||
|
|
||||||
|
if s, o := SqlSpecial(31); !s {
|
||||||
|
t.Error("char is special")
|
||||||
|
} else if !o {
|
||||||
|
t.Error("\" must be omitted")
|
||||||
|
}
|
||||||
|
|
||||||
|
if s, o := SqlSpecial('\\'); !s {
|
||||||
|
t.Error("\\ is special")
|
||||||
|
} else if o {
|
||||||
|
t.Error("\\ must not be omitted")
|
||||||
|
}
|
||||||
|
|
||||||
|
if s, o := SqlSpecial('\''); !s {
|
||||||
|
t.Error("' is special")
|
||||||
|
} else if o {
|
||||||
|
t.Error("' must not be omitted")
|
||||||
|
}
|
||||||
|
|
||||||
|
if s, o := SqlSpecial('"'); !s {
|
||||||
|
t.Error("\" is special")
|
||||||
|
} else if o {
|
||||||
|
t.Error("\" must not be omitted")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("NotSpecial", func(t *testing.T) {
|
||||||
|
if s, o := SqlSpecial(32); s {
|
||||||
|
t.Error("space is not special")
|
||||||
|
} else if o {
|
||||||
|
t.Error("space must not be omitted")
|
||||||
|
}
|
||||||
|
|
||||||
|
if s, o := SqlSpecial('A'); s {
|
||||||
|
t.Error("A is not special")
|
||||||
|
} else if o {
|
||||||
|
t.Error("A must not be omitted")
|
||||||
|
}
|
||||||
|
|
||||||
|
if s, o := SqlSpecial('a'); s {
|
||||||
|
t.Error("a is not special")
|
||||||
|
} else if o {
|
||||||
|
t.Error("a must not be omitted")
|
||||||
|
}
|
||||||
|
|
||||||
|
if s, o := SqlSpecial('_'); s {
|
||||||
|
t.Error("_ is not special")
|
||||||
|
} else if o {
|
||||||
|
t.Error("_ must not be omitted")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestSqlString(t *testing.T) {
|
func TestSqlString(t *testing.T) {
|
||||||
t.Run("Empty", func(t *testing.T) {
|
t.Run("Empty", func(t *testing.T) {
|
||||||
assert.Equal(t, "", SqlString(""))
|
assert.Equal(t, "", SqlString(""))
|
||||||
})
|
})
|
||||||
t.Run("Special", func(t *testing.T) {
|
t.Run("Special", func(t *testing.T) {
|
||||||
s := "' \" \t \n %_''"
|
s := "' \" \t \n %_''"
|
||||||
exp := "\\' \\\" %\\_\\'\\'"
|
exp := "'' \"\" %_''''"
|
||||||
result := SqlString(s)
|
result := SqlString(s)
|
||||||
t.Logf("String..: %s", s)
|
t.Logf("String..: %s", s)
|
||||||
t.Logf("Expected: %s", exp)
|
t.Logf("Expected: %s", exp)
|
||||||
|
|
Loading…
Reference in a new issue