diff --git a/internal/search/like.go b/internal/search/like.go index 2b17665ca..c906fd18e 100644 --- a/internal/search/like.go +++ b/internal/search/like.go @@ -10,6 +10,11 @@ import ( "github.com/jinzhu/inflection" ) +// Like escapes a string for use in a query. +func Like(s string) string { + return strings.Trim(sanitize.SqlString(s), " |&*%") +} + // LikeAny returns a single where condition matching the search words. func LikeAny(col, s string, keywords, exact bool) (wheres []string) { if s == "" { @@ -44,9 +49,9 @@ func LikeAny(col, s string, keywords, exact bool) (wheres []string) { for _, w := range words { if wildcardThreshold > 0 && len(w) >= wildcardThreshold { - orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s%%'", col, SqlLike(w))) + orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s%%'", col, Like(w))) } else { - orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s'", col, SqlLike(w))) + orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s'", col, Like(w))) } if !keywords || !txt.ContainsASCIILetters(w) { @@ -56,7 +61,7 @@ func LikeAny(col, s string, keywords, exact bool) (wheres []string) { singular := inflection.Singular(w) if singular != w { - orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s'", col, SqlLike(singular))) + orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s'", col, Like(singular))) } } @@ -103,9 +108,9 @@ func LikeAll(col, s string, keywords, exact bool) (wheres []string) { for _, w := range words { if wildcardThreshold > 0 && len(w) >= wildcardThreshold { - wheres = append(wheres, fmt.Sprintf("%s LIKE '%s%%'", col, SqlLike(w))) + wheres = append(wheres, fmt.Sprintf("%s LIKE '%s%%'", col, Like(w))) } else { - wheres = append(wheres, fmt.Sprintf("%s LIKE '%s'", col, SqlLike(w))) + wheres = append(wheres, fmt.Sprintf("%s LIKE '%s'", col, Like(w))) } } @@ -140,9 +145,9 @@ func LikeAllNames(cols Cols, s string) (wheres []string) { for _, c := range cols { if strings.Contains(w, txt.Space) { - orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s%%'", c, SqlLike(w))) + orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s%%'", c, Like(w))) } else { - orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%%%s%%'", c, SqlLike(w))) + orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%%%s%%'", c, Like(w))) } } } @@ -189,7 +194,7 @@ func AnySlug(col, search, sep string) (where string) { } for _, w := range words { - wheres = append(wheres, fmt.Sprintf("%s = '%s'", col, SqlLike(w))) + wheres = append(wheres, fmt.Sprintf("%s = '%s'", col, Like(w))) } return strings.Join(wheres, " OR ") diff --git a/internal/search/like_test.go b/internal/search/like_test.go index c02b72728..3b5298410 100644 --- a/internal/search/like_test.go +++ b/internal/search/like_test.go @@ -12,6 +12,24 @@ import ( "github.com/stretchr/testify/assert" ) +func TestLike(t *testing.T) { + t.Run("Empty", func(t *testing.T) { + assert.Equal(t, "", Like("")) + }) + t.Run("Special", func(t *testing.T) { + s := " ' \" \t \n %_''" + exp := "'' \"\" %_''''" + result := Like(s) + t.Logf("String..: %s", s) + t.Logf("Expected: %s", exp) + t.Logf("Result..: %s", result) + assert.Equal(t, exp, result) + }) + t.Run("Alnum", func(t *testing.T) { + assert.Equal(t, "123ABCabc", Like(" 123ABCabc%* ")) + }) +} + func TestLikeAny(t *testing.T) { t.Run("and_or_search", func(t *testing.T) { if w := LikeAny("k.keyword", "table spoon & usa | img json", true, false); len(w) != 2 { @@ -119,11 +137,11 @@ func TestLikeAnyWord(t *testing.T) { } }) t.Run("EscapeSql", func(t *testing.T) { - if w := LikeAnyWord("k.keyword", "table% | 'spoon' & \"usa"); len(w) != 2 { + if w := LikeAnyWord("k.keyword", "table% | 'spoon' & \"us'a"); len(w) != 2 { t.Fatalf("two where conditions expected: %#v", w) } else { assert.Equal(t, "k.keyword LIKE 'spoon%' OR k.keyword LIKE 'table%'", w[0]) - assert.Equal(t, "k.keyword LIKE '\\\"usa%'", w[1]) + assert.Equal(t, "k.keyword LIKE '\"\"us''a%'", w[1]) } }) } diff --git a/internal/search/photos_filter_albums_test.go b/internal/search/photos_filter_albums_test.go index 7c53b88a6..a23f5de7f 100644 --- a/internal/search/photos_filter_albums_test.go +++ b/internal/search/photos_filter_albums_test.go @@ -99,8 +99,7 @@ func TestPhotosFilterAlbums(t *testing.T) { } assert.Equal(t, len(photos), 0) }) - // TODO should not throw error - /*t.Run("albums middle '", func(t *testing.T) { + t.Run("AlbumsSingleQuote", func(t *testing.T) { var f form.SearchPhotos f.Albums = "Father's Day" @@ -113,7 +112,7 @@ func TestPhotosFilterAlbums(t *testing.T) { } assert.Greater(t, len(photos), 0) - })*/ + }) t.Run("albums end '", func(t *testing.T) { var f form.SearchPhotos @@ -190,8 +189,8 @@ func TestPhotosFilterAlbums(t *testing.T) { if err != nil { t.Fatal(err) } - // TODO: Needs review, variable number of results. + // TODO: Needs review, variable number of results. assert.GreaterOrEqual(t, len(photos), 0) }) t.Run("albums end |", func(t *testing.T) { @@ -340,8 +339,7 @@ func TestPhotosQueryAlbums(t *testing.T) { } assert.Equal(t, len(photos), 0) }) - //TODO should not throw error - /*t.Run("albums middle '", func(t *testing.T) { + t.Run("AlbumsQuerySingleQuote", func(t *testing.T) { var f form.SearchPhotos f.Query = "albums:\"Father's Day\"" @@ -354,7 +352,7 @@ func TestPhotosQueryAlbums(t *testing.T) { } assert.Greater(t, len(photos), 0) - })*/ + }) t.Run("albums end '", func(t *testing.T) { var f form.SearchPhotos @@ -431,8 +429,8 @@ func TestPhotosQueryAlbums(t *testing.T) { if err != nil { t.Fatal(err) } - // TODO: Needs review, variable number of results. + // TODO: Needs review, variable number of results. assert.GreaterOrEqual(t, len(photos), 0) }) t.Run("albums end |", func(t *testing.T) { diff --git a/internal/search/sql.go b/internal/search/sql.go deleted file mode 100644 index 901b2a091..000000000 --- a/internal/search/sql.go +++ /dev/null @@ -1,12 +0,0 @@ -package search - -import ( - "strings" - - "github.com/photoprism/photoprism/pkg/sanitize" -) - -// SqlLike escapes a string for use in an SQL query. -func SqlLike(s string) string { - return strings.Trim(sanitize.SqlString(s), " |&*%") -} diff --git a/internal/search/sql_test.go b/internal/search/sql_test.go deleted file mode 100644 index aa68d2718..000000000 --- a/internal/search/sql_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package search - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestSqlLike(t *testing.T) { - t.Run("Empty", func(t *testing.T) { - assert.Equal(t, "", SqlLike("")) - }) - t.Run("Special", func(t *testing.T) { - s := "' \" \t \n %_''" - exp := "\\' \\\" %\\_\\'\\'" - result := SqlLike(s) - t.Logf("String..: %s", s) - t.Logf("Expected: %s", exp) - t.Logf("Result..: %s", result) - assert.Equal(t, exp, result) - }) - t.Run("Alnum", func(t *testing.T) { - assert.Equal(t, "123ABCabc", SqlLike(" 123ABCabc%* ")) - }) -} diff --git a/pkg/sanitize/sql.go b/pkg/sanitize/sql.go index 944fce02b..edf84bfe0 100644 --- a/pkg/sanitize/sql.go +++ b/pkg/sanitize/sql.go @@ -1,41 +1,53 @@ package sanitize -import ( - "bytes" -) +// SqlSpecial checks if the byte must be escaped/omitted in SQL. +func SqlSpecial(b byte) (special bool, omit bool) { + if b < 32 { + return true, true + } -// sqlSpecialBytes contains special bytes to escape in SQL search queries. -// see https://mariadb.com/kb/en/string-literals/ -var sqlSpecialBytes = []byte{34, 39, 92, 95} // ", ', \, _ + switch b { + case '"', '\'', '\\': + return true, false + default: + return false, false + } +} // SqlString escapes a string for use in an SQL query. func SqlString(s string) string { var i int for i = 0; i < len(s); i++ { - if bytes.Contains(sqlSpecialBytes, []byte{s[i]}) { + if found, _ := SqlSpecial(s[i]); found { break } } - // No special characters found, return original string. + // Return if no special characters were found. if i >= len(s) { return s } b := make([]byte, 2*len(s)-i) + copy(b, s[:i]) + j := i + for ; i < len(s); i++ { - if s[i] < 31 { - // Ignore control chars. + if special, omit := SqlSpecial(s[i]); omit { + // Omit control characters. continue - } - if bytes.Contains(sqlSpecialBytes, []byte{s[i]}) { - b[j] = '\\' + } else if special { + // Escape other special characters. + // see https://mariadb.com/kb/en/string-literals/ + b[j] = s[i] j++ } + b[j] = s[i] j++ } + return string(b[:j]) } diff --git a/pkg/sanitize/sql_test.go b/pkg/sanitize/sql_test.go index 18bfb47f5..c85efcd73 100644 --- a/pkg/sanitize/sql_test.go +++ b/pkg/sanitize/sql_test.go @@ -6,13 +6,72 @@ import ( "github.com/stretchr/testify/assert" ) +func TestSqlSpecial(t *testing.T) { + t.Run("Special", func(t *testing.T) { + if s, o := SqlSpecial(1); !s { + t.Error("char is special") + } else if !o { + t.Error("\" must be omitted") + } + + if s, o := SqlSpecial(31); !s { + t.Error("char is special") + } else if !o { + t.Error("\" must be omitted") + } + + if s, o := SqlSpecial('\\'); !s { + t.Error("\\ is special") + } else if o { + t.Error("\\ must not be omitted") + } + + if s, o := SqlSpecial('\''); !s { + t.Error("' is special") + } else if o { + t.Error("' must not be omitted") + } + + if s, o := SqlSpecial('"'); !s { + t.Error("\" is special") + } else if o { + t.Error("\" must not be omitted") + } + }) + t.Run("NotSpecial", func(t *testing.T) { + if s, o := SqlSpecial(32); s { + t.Error("space is not special") + } else if o { + t.Error("space must not be omitted") + } + + if s, o := SqlSpecial('A'); s { + t.Error("A is not special") + } else if o { + t.Error("A must not be omitted") + } + + if s, o := SqlSpecial('a'); s { + t.Error("a is not special") + } else if o { + t.Error("a must not be omitted") + } + + if s, o := SqlSpecial('_'); s { + t.Error("_ is not special") + } else if o { + t.Error("_ must not be omitted") + } + }) +} + func TestSqlString(t *testing.T) { t.Run("Empty", func(t *testing.T) { assert.Equal(t, "", SqlString("")) }) t.Run("Special", func(t *testing.T) { s := "' \" \t \n %_''" - exp := "\\' \\\" %\\_\\'\\'" + exp := "'' \"\" %_''''" result := SqlString(s) t.Logf("String..: %s", s) t.Logf("Expected: %s", exp)