Fix NuGet Package API for $filter with Id equality (#31188)

Fixes issue when running `choco info pkgname` where `pkgname` is also a
substring of another package Id.

Relates to #31168

---

This might fix the issue linked, but I'd like to test it with more choco
commands before closing the issue in case I find other problems if
that's ok.

---------

Co-authored-by: KN4CK3R <admin@oldschoolhack.me>
This commit is contained in:
Thomas Desveaux 2024-06-04 08:45:56 +02:00 committed by GitHub
parent 4f9b8b397c
commit c888c933a9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 115 additions and 35 deletions

@ -96,20 +96,34 @@ func FeedCapabilityResource(ctx *context.Context) {
xmlResponse(ctx, http.StatusOK, Metadata) xmlResponse(ctx, http.StatusOK, Metadata)
} }
var searchTermExtract = regexp.MustCompile(`'([^']+)'`) var (
searchTermExtract = regexp.MustCompile(`'([^']+)'`)
searchTermExact = regexp.MustCompile(`\s+eq\s+'`)
)
func getSearchTerm(ctx *context.Context) string { func getSearchTerm(ctx *context.Context) packages_model.SearchValue {
searchTerm := strings.Trim(ctx.FormTrim("searchTerm"), "'") searchTerm := strings.Trim(ctx.FormTrim("searchTerm"), "'")
if searchTerm == "" { if searchTerm != "" {
return packages_model.SearchValue{
Value: searchTerm,
ExactMatch: false,
}
}
// $filter contains a query like: // $filter contains a query like:
// (((Id ne null) and substringof('microsoft',tolower(Id))) // (((Id ne null) and substringof('microsoft',tolower(Id)))
// https://www.odata.org/documentation/odata-version-2-0/uri-conventions/ section 4.5
// We don't support these queries, just extract the search term. // We don't support these queries, just extract the search term.
match := searchTermExtract.FindStringSubmatch(ctx.FormTrim("$filter")) filter := ctx.FormTrim("$filter")
match := searchTermExtract.FindStringSubmatch(filter)
if len(match) == 2 { if len(match) == 2 {
searchTerm = strings.TrimSpace(match[1]) return packages_model.SearchValue{
Value: strings.TrimSpace(match[1]),
ExactMatch: searchTermExact.MatchString(filter),
} }
} }
return searchTerm
return packages_model.SearchValue{}
} }
// https://github.com/NuGet/NuGet.Client/blob/dev/src/NuGet.Core/NuGet.Protocol/LegacyFeed/V2FeedQueryBuilder.cs // https://github.com/NuGet/NuGet.Client/blob/dev/src/NuGet.Core/NuGet.Protocol/LegacyFeed/V2FeedQueryBuilder.cs
@ -120,9 +134,7 @@ func SearchServiceV2(ctx *context.Context) {
pvs, total, err := packages_model.SearchLatestVersions(ctx, &packages_model.PackageSearchOptions{ pvs, total, err := packages_model.SearchLatestVersions(ctx, &packages_model.PackageSearchOptions{
OwnerID: ctx.Package.Owner.ID, OwnerID: ctx.Package.Owner.ID,
Type: packages_model.TypeNuGet, Type: packages_model.TypeNuGet,
Name: packages_model.SearchValue{ Name: getSearchTerm(ctx),
Value: getSearchTerm(ctx),
},
IsInternal: optional.Some(false), IsInternal: optional.Some(false),
Paginator: paginator, Paginator: paginator,
}) })
@ -170,9 +182,7 @@ func SearchServiceV2(ctx *context.Context) {
func SearchServiceV2Count(ctx *context.Context) { func SearchServiceV2Count(ctx *context.Context) {
count, err := nuget_model.CountPackages(ctx, &packages_model.PackageSearchOptions{ count, err := nuget_model.CountPackages(ctx, &packages_model.PackageSearchOptions{
OwnerID: ctx.Package.Owner.ID, OwnerID: ctx.Package.Owner.ID,
Name: packages_model.SearchValue{ Name: getSearchTerm(ctx),
Value: getSearchTerm(ctx),
},
IsInternal: optional.Some(false), IsInternal: optional.Some(false),
}) })
if err != nil { if err != nil {

@ -434,17 +434,28 @@ AAAjQmxvYgAAAGm7ENm9SGxMtAFVvPUsPJTF6PbtAAAAAFcVogEJAAAAAQAAAA==`)
Take int Take int
ExpectedTotal int64 ExpectedTotal int64
ExpectedResults int ExpectedResults int
ExpectedExactMatch bool
}{ }{
{"", 0, 0, 1, 1}, {"", 0, 0, 4, 4, false},
{"", 0, 10, 1, 1}, {"", 0, 10, 4, 4, false},
{"gitea", 0, 10, 0, 0}, {"gitea", 0, 10, 0, 0, false},
{"test", 0, 10, 1, 1}, {"test", 0, 10, 1, 1, false},
{"test", 1, 10, 1, 0}, {"test", 1, 10, 1, 0, false},
{"almost.similar", 0, 0, 3, 3, true},
} }
req := NewRequestWithBody(t, "PUT", url, createPackage(packageName, "1.0.99")). fakePackages := []string{
packageName,
"almost.similar.dependency",
"almost.similar",
"almost.similar.dependant",
}
for _, fakePackageName := range fakePackages {
req := NewRequestWithBody(t, "PUT", url, createPackage(fakePackageName, "1.0.99")).
AddBasicAuth(user.Name) AddBasicAuth(user.Name)
MakeRequest(t, req, http.StatusCreated) MakeRequest(t, req, http.StatusCreated)
}
t.Run("v2", func(t *testing.T) { t.Run("v2", func(t *testing.T) {
t.Run("Search()", func(t *testing.T) { t.Run("Search()", func(t *testing.T) {
@ -491,6 +502,63 @@ AAAjQmxvYgAAAGm7ENm9SGxMtAFVvPUsPJTF6PbtAAAAAFcVogEJAAAAAQAAAA==`)
} }
}) })
t.Run("Packages()", func(t *testing.T) {
defer tests.PrintCurrentTest(t)()
t.Run("substringof", func(t *testing.T) {
defer tests.PrintCurrentTest(t)()
for i, c := range cases {
req := NewRequest(t, "GET", fmt.Sprintf("%s/Packages()?$filter=substringof('%s',tolower(Id))&$skip=%d&$top=%d", url, c.Query, c.Skip, c.Take)).
AddBasicAuth(user.Name)
resp := MakeRequest(t, req, http.StatusOK)
var result FeedResponse
decodeXML(t, resp, &result)
assert.Equal(t, c.ExpectedTotal, result.Count, "case %d: unexpected total hits", i)
assert.Len(t, result.Entries, c.ExpectedResults, "case %d: unexpected result count", i)
req = NewRequest(t, "GET", fmt.Sprintf("%s/Packages()/$count?$filter=substringof('%s',tolower(Id))&$skip=%d&$top=%d", url, c.Query, c.Skip, c.Take)).
AddBasicAuth(user.Name)
resp = MakeRequest(t, req, http.StatusOK)
assert.Equal(t, strconv.FormatInt(c.ExpectedTotal, 10), resp.Body.String(), "case %d: unexpected total hits", i)
}
})
t.Run("IdEq", func(t *testing.T) {
defer tests.PrintCurrentTest(t)()
for i, c := range cases {
if c.Query == "" {
// Ignore the `tolower(Id) eq ''` as it's unlikely to happen
continue
}
req := NewRequest(t, "GET", fmt.Sprintf("%s/Packages()?$filter=(tolower(Id) eq '%s')&$skip=%d&$top=%d", url, c.Query, c.Skip, c.Take)).
AddBasicAuth(user.Name)
resp := MakeRequest(t, req, http.StatusOK)
var result FeedResponse
decodeXML(t, resp, &result)
expectedCount := 0
if c.ExpectedExactMatch {
expectedCount = 1
}
assert.Equal(t, int64(expectedCount), result.Count, "case %d: unexpected total hits", i)
assert.Len(t, result.Entries, expectedCount, "case %d: unexpected result count", i)
req = NewRequest(t, "GET", fmt.Sprintf("%s/Packages()/$count?$filter=(tolower(Id) eq '%s')&$skip=%d&$top=%d", url, c.Query, c.Skip, c.Take)).
AddBasicAuth(user.Name)
resp = MakeRequest(t, req, http.StatusOK)
assert.Equal(t, strconv.FormatInt(int64(expectedCount), 10), resp.Body.String(), "case %d: unexpected total hits", i)
}
})
})
t.Run("Next", func(t *testing.T) { t.Run("Next", func(t *testing.T) {
req := NewRequest(t, "GET", fmt.Sprintf("%s/Search()?searchTerm='test'&$skip=0&$top=1", url)). req := NewRequest(t, "GET", fmt.Sprintf("%s/Search()?searchTerm='test'&$skip=0&$top=1", url)).
AddBasicAuth(user.Name) AddBasicAuth(user.Name)
@ -548,9 +616,11 @@ AAAjQmxvYgAAAGm7ENm9SGxMtAFVvPUsPJTF6PbtAAAAAFcVogEJAAAAAQAAAA==`)
}) })
}) })
req = NewRequest(t, "DELETE", fmt.Sprintf("%s/%s/%s", url, packageName, "1.0.99")). for _, fakePackageName := range fakePackages {
req := NewRequest(t, "DELETE", fmt.Sprintf("%s/%s/%s", url, fakePackageName, "1.0.99")).
AddBasicAuth(user.Name) AddBasicAuth(user.Name)
MakeRequest(t, req, http.StatusNoContent) MakeRequest(t, req, http.StatusNoContent)
}
}) })
t.Run("RegistrationService", func(t *testing.T) { t.Run("RegistrationService", func(t *testing.T) {