Skip to content

Commit d0465cc

Browse files
committed
Use paginated queryset for better memory performance
Signed-off-by: Keshav Priyadarshi <git@keshav.space>
1 parent 5c8770b commit d0465cc

3 files changed

Lines changed: 70 additions & 48 deletions

File tree

vulnerabilities/pipelines/flag_ghost_packages.py

Lines changed: 36 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
#
99

1010
import logging
11+
from itertools import groupby
1112
from traceback import format_exc as traceback_format_exc
1213

1314
from aboutcode.pipeline import LoopProgress
14-
from fetchcode.package_versions import SUPPORTED_ECOSYSTEMS
15+
from fetchcode.package_versions import SUPPORTED_ECOSYSTEMS as FETCHCODE_SUPPORTED_ECOSYSTEMS
1516
from fetchcode.package_versions import versions
1617
from packageurl import PackageURL
1718
from univers.version_range import RANGE_CLASS_BY_SCHEMES
@@ -32,74 +33,76 @@ def flag_ghost_packages(self):
3233

3334

3435
def detect_and_flag_ghost_packages(logger=None):
35-
"""Use fetchcode to validate the package indeed exists upstream."""
36+
"""Check if packages are available upstream. If not, mark them as ghost packages."""
3637
interesting_packages_qs = (
37-
Package.objects.filter(type__in=SUPPORTED_ECOSYSTEMS)
38+
Package.objects.order_by("type", "namespace", "name")
39+
.filter(type__in=FETCHCODE_SUPPORTED_ECOSYSTEMS)
3840
.filter(qualifiers="")
3941
.filter(subpath="")
4042
)
4143

42-
distinct_packages = interesting_packages_qs.values("type", "namespace", "name").distinct(
43-
"type", "namespace", "name"
44+
distinct_packages_count = (
45+
interesting_packages_qs.values("type", "namespace", "name")
46+
.distinct("type", "namespace", "name")
47+
.count()
4448
)
4549

46-
distinct_packages_count = distinct_packages.count()
47-
package_iterator = distinct_packages.iterator(chunk_size=2000)
48-
progress = LoopProgress(total_iterations=distinct_packages_count, logger=logger)
50+
grouped_packages = groupby(
51+
interesting_packages_qs.paginated(),
52+
key=lambda pkg: (pkg.type, pkg.namespace, pkg.name),
53+
)
4954

5055
ghost_package_count = 0
51-
52-
for package in progress.iter(package_iterator):
56+
progress = LoopProgress(total_iterations=distinct_packages_count, logger=logger)
57+
for type_namespace_name, packages in progress.iter(grouped_packages):
5358
ghost_package_count += flag_ghost_package(
54-
package_dict=package,
55-
interesting_packages_qs=interesting_packages_qs,
59+
base_purl=PackageURL(*type_namespace_name),
60+
packages=packages,
5661
logger=logger,
5762
)
5863

5964
if logger:
6065
logger(f"Successfully flagged {ghost_package_count:,d} ghost Packages")
6166

6267

63-
def flag_ghost_package(package_dict, interesting_packages_qs, logger=None):
68+
def flag_ghost_package(base_purl, packages, logger=None):
6469
"""
65-
Check if all the versions of the package described by `package_dict` (type, namespace, name)
66-
are available upstream. If they are not available, update the status to 'ghost'.
67-
Otherwise, update the status to 'valid'.
70+
Check if all the versions of the `base_purl` package are available upstream.
71+
Set `is_ghost` to `True` for the versions that are not available and to `False` for the others.
6872
"""
69-
if not package_dict["type"] in RANGE_CLASS_BY_SCHEMES:
73+
if not base_purl.type in RANGE_CLASS_BY_SCHEMES:
7074
return 0
7175

72-
known_versions = get_versions(**package_dict, logger=logger)
73-
if not known_versions:
76+
known_versions = get_versions(purl=base_purl, logger=logger)
77+
# Skip if an error was encountered while fetching known versions
78+
if known_versions is None:
7479
return 0
7580

76-
version_class = RANGE_CLASS_BY_SCHEMES[package_dict["type"]].version_class
77-
package_versions = interesting_packages_qs.filter(**package_dict).filter(status="unknown")
78-
7981
ghost_packages = 0
80-
for pkg in package_versions:
82+
version_class = RANGE_CLASS_BY_SCHEMES[base_purl.type].version_class
83+
for pkg in packages:
84+
pkg.is_ghost = False
8185
if version_class(pkg.version) not in known_versions:
82-
pkg.status = "ghost"
83-
pkg.save()
86+
pkg.is_ghost = True
8487
ghost_packages += 1
8588

86-
valid_package_versions = package_versions.exclude(status="ghost")
87-
valid_package_versions.update(status="valid")
89+
if logger:
90+
logger(f"Flagging ghost package {pkg.purl!s}", level=logging.DEBUG)
91+
pkg.save()
8892

8993
return ghost_packages
9094

9195

92-
def get_versions(type, namespace, name, logger=None):
93-
"""Return set of known versions for the given package type, namespace, and name."""
94-
versionless_purl = PackageURL(type=type, namespace=namespace, name=name)
95-
version_class = RANGE_CLASS_BY_SCHEMES[type].version_class
96+
def get_versions(purl, logger=None):
97+
"""Return the set of known versions for the given purl, or None if fetching them fails."""
98+
version_class = RANGE_CLASS_BY_SCHEMES[purl.type].version_class
9699

97100
try:
98-
return {version_class(v.value) for v in versions(str(versionless_purl))}
101+
return {version_class(v.value) for v in versions(str(purl))}
99102
except Exception as e:
100103
if logger:
101104
logger(
102-
f"Error while fetching known versions for {versionless_purl!r}: {e!r} \n {traceback_format_exc()}",
105+
f"Error while fetching known versions for {purl!s}: {e!r} \n {traceback_format_exc()}",
103106
level=logging.ERROR,
104107
)
105108
return
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#
2+
# Copyright (c) nexB Inc. and others. All rights reserved.
3+
# VulnerableCode is a trademark of nexB Inc.
4+
# SPDX-License-Identifier: Apache-2.0
5+
# See http://www.apache.org/licenses/LICENSE-2.0 for the license text.
6+
# See https://github.com/nexB/vulnerablecode for support or download.
7+
# See https://aboutcode.org for more information about nexB OSS projects.
8+
#
9+
10+
import io
11+
12+
13+
class TestLogger:
    """
    In-memory logger stand-in for pipeline tests.

    Pass the bound ``write`` method wherever a pipeline accepts a ``logger``
    callable, then inspect everything that was logged with ``getvalue()``.
    """

    # The "Test" prefix makes pytest consider this class for collection;
    # mark it explicitly as a helper rather than a test class.
    __test__ = False

    def __init__(self):
        # One buffer per instance: a class-level buffer would be shared by
        # every TestLogger and leak log output from one test into another.
        self.buffer = io.StringIO()

    def write(self, msg, level=None):
        """Append ``msg`` to the buffer; ``level`` is accepted but ignored."""
        self.buffer.write(msg)

    def getvalue(self):
        """Return everything written so far as a single string."""
        return self.buffer.getvalue()

vulnerabilities/tests/pipelines/test_flag_ghost_packages.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,17 @@
77
# See https://aboutcode.org for more information about nexB OSS projects.
88
#
99

10-
import io
10+
1111
from pathlib import Path
1212
from unittest import mock
1313

1414
from django.test import TestCase
1515
from fetchcode.package_versions import PackageVersion
16+
from packageurl import PackageURL
1617

1718
from vulnerabilities.models import Package
1819
from vulnerabilities.pipelines import flag_ghost_packages
20+
from vulnerabilities.tests.pipelines import TestLogger
1921

2022

2123
class FlagGhostPackagePipelineTest(TestCase):
@@ -30,20 +32,16 @@ def test_flag_ghost_package(self, mock_fetchcode_versions):
3032
PackageVersion(value="2.3.0"),
3133
]
3234
interesting_packages_qs = Package.objects.all()
33-
target_package = {
34-
"type": "pypi",
35-
"namespace": "",
36-
"name": "foo",
37-
}
35+
base_purl = PackageURL(type="pypi", name="foo")
3836

39-
self.assertEqual(0, Package.objects.filter(status="ghost").count())
37+
self.assertEqual(0, Package.objects.filter(is_ghost=True).count())
4038

4139
flagged_package_count = flag_ghost_packages.flag_ghost_package(
42-
package_dict=target_package,
43-
interesting_packages_qs=interesting_packages_qs,
40+
base_purl=base_purl,
41+
packages=interesting_packages_qs,
4442
)
4543
self.assertEqual(1, flagged_package_count)
46-
self.assertEqual(1, Package.objects.filter(status="ghost").count())
44+
self.assertEqual(1, Package.objects.filter(is_ghost=True).count())
4745

4846
@mock.patch("vulnerabilities.pipelines.flag_ghost_packages.versions")
4947
def test_detect_and_flag_ghost_packages(self, mock_fetchcode_versions):
@@ -62,11 +60,12 @@ def test_detect_and_flag_ghost_packages(self, mock_fetchcode_versions):
6260
]
6361

6462
self.assertEqual(3, Package.objects.count())
65-
self.assertEqual(0, Package.objects.filter(status="ghost").count())
63+
self.assertEqual(0, Package.objects.filter(is_ghost=True).count())
64+
65+
logger = TestLogger()
6666

67-
buffer = io.StringIO()
68-
flag_ghost_packages.detect_and_flag_ghost_packages(logger=buffer.write)
67+
flag_ghost_packages.detect_and_flag_ghost_packages(logger=logger.write)
6968
expected = "Successfully flagged 1 ghost Packages"
7069

71-
self.assertIn(expected, buffer.getvalue())
72-
self.assertEqual(1, Package.objects.filter(status="ghost").count())
70+
self.assertIn(expected, logger.getvalue())
71+
self.assertEqual(1, Package.objects.filter(is_ghost=True).count())

0 commit comments

Comments
 (0)