Skip to content

Commit ca9a61d

Browse files
committed
Update download urls from figshare to s3 bucket
1 parent 493ed25 commit ca9a61d

6 files changed

Lines changed: 366 additions & 63 deletions

File tree

cebra/data/assets.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
# limitations under the License.
2121
#
2222

23+
import gzip
2324
import hashlib
2425
import re
2526
import warnings
@@ -140,3 +141,128 @@ def calculate_checksum(file_path: str) -> str:
140141
for chunk in iter(lambda: file.read(4096), b""):
141142
checksum.update(chunk)
142143
return checksum.hexdigest()
144+
145+
146+
def download_and_extract_gzipped_file(url: str,
                                      expected_checksum: str,
                                      gzipped_checksum: str,
                                      location: str,
                                      file_name: str,
                                      retry_count: int = 0) -> Optional[Path]:
    """Download a gzipped file from the given URL, verify checksums, and extract.

    The archive is streamed to ``<location>/<file_name>.gz`` while its MD5
    checksum is accumulated, the gzipped checksum is verified, the archive is
    decompressed to ``<location>/<file_name>``, and the checksum of the
    extracted file is verified as well. On any checksum or extraction failure
    the partial files are deleted and the whole procedure is retried, up to
    ``_MAX_RETRY_COUNT`` attempts.

    Args:
        url: The URL to download the gzipped file from.
        expected_checksum: The expected MD5 checksum of the extracted file.
        gzipped_checksum: The expected MD5 checksum of the gzipped file.
        location: The directory where the file will be saved.
        file_name: The name of the final extracted file (without .gz extension).
        retry_count: The number of retry attempts (default: 0).

    Returns:
        The path of the extracted file if successful, None otherwise.

    Raises:
        RuntimeError: If the maximum retry count is exceeded.
        requests.HTTPError: If the download fails.
    """
    # Short-circuit: the extracted file already exists with the right checksum.
    location_path = Path(location)
    final_file_path = location_path / file_name

    if final_file_path.exists():
        if calculate_checksum(final_file_path) == expected_checksum:
            return final_file_path

    if retry_count >= _MAX_RETRY_COUNT:
        raise RuntimeError(
            f"Exceeded maximum retry count ({_MAX_RETRY_COUNT}). "
            f"Unable to download the file from {url}")

    def _retry() -> Optional[Path]:
        # Single definition of the retry recursion shared by all failure paths.
        return download_and_extract_gzipped_file(url, expected_checksum,
                                                 gzipped_checksum, location,
                                                 file_name, retry_count + 1)

    # Create the directory and any necessary parent directories.
    location_path.mkdir(parents=True, exist_ok=True)

    gz_file_path = location_path / f"{file_name}.gz"

    # ``with`` releases the streamed connection even on error; the timeout
    # guards against a hung connection (with stream=True it applies per read).
    with requests.get(url, stream=True, timeout=60) as response:
        if response.status_code != 200:
            raise requests.HTTPError(
                f"Error occurred while downloading the file. Response code: {response.status_code}"
            )

        total_size = int(response.headers.get("Content-Length", 0))
        checksum = hashlib.md5()  # checksum of the *gzipped* payload

        with open(gz_file_path, "wb") as file:
            with tqdm.tqdm(total=total_size,
                           unit="B",
                           unit_scale=True,
                           desc="Downloading") as progress_bar:
                for data in response.iter_content(chunk_size=1024):
                    file.write(data)
                    checksum.update(data)
                    progress_bar.update(len(data))

    # Verify the gzipped file before spending time on extraction.
    if checksum.hexdigest() != gzipped_checksum:
        warnings.warn(
            f"Gzipped file checksum verification failed. Deleting '{gz_file_path}'."
        )
        gz_file_path.unlink()
        warnings.warn("Gzipped file deleted. Retrying download...")
        return _retry()

    print("Gzipped file checksum verified. Extracting...")

    try:
        with gzip.open(gz_file_path, 'rb') as f_in:
            with open(final_file_path, 'wb') as f_out:
                # Decompress in fixed-size chunks to bound memory usage; same
                # chunked-read idiom as calculate_checksum above.
                for chunk in iter(lambda: f_in.read(8192), b""):
                    f_out.write(chunk)
    except Exception as e:
        warnings.warn(f"Extraction failed: {e}. Deleting files and retrying...")
        if gz_file_path.exists():
            gz_file_path.unlink()
        if final_file_path.exists():
            final_file_path.unlink()
        return _retry()

    # Verify the checksum of the *extracted* file.
    if calculate_checksum(final_file_path) != expected_checksum:
        warnings.warn(
            "Extracted file checksum verification failed. Deleting files.")
        gz_file_path.unlink()
        final_file_path.unlink()
        warnings.warn("Files deleted. Retrying download...")
        return _retry()

    # The archive is no longer needed once extraction succeeded.
    gz_file_path.unlink()

    print(f"Extraction complete. Dataset saved in '{final_file_path}'")
    return final_file_path

cebra/data/base.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def __init__(self,
5555
download=False,
5656
data_url=None,
5757
data_checksum=None,
58+
gzipped_checksum=None,
5859
location=None,
5960
file_name=None):
6061

@@ -64,6 +65,7 @@ def __init__(self,
6465
self.download = download
6566
self.data_url = data_url
6667
self.data_checksum = data_checksum
68+
self.gzipped_checksum = gzipped_checksum
6769
self.location = location
6870
self.file_name = file_name
6971

@@ -78,11 +80,21 @@ def __init__(self,
7880
"Missing data checksum. Please provide the checksum to verify the data integrity."
7981
)
8082

81-
cebra_data_assets.download_file_with_progress_bar(
82-
url=self.data_url,
83-
expected_checksum=self.data_checksum,
84-
location=self.location,
85-
file_name=self.file_name)
83+
# Use gzipped download if gzipped_checksum is provided
84+
if self.gzipped_checksum is not None:
85+
cebra_data_assets.download_and_extract_gzipped_file(
86+
url=self.data_url,
87+
expected_checksum=self.data_checksum,
88+
gzipped_checksum=self.gzipped_checksum,
89+
location=self.location,
90+
file_name=self.file_name)
91+
else:
92+
# Fall back to legacy download for backward compatibility
93+
cebra_data_assets.download_file_with_progress_bar(
94+
url=self.data_url,
95+
expected_checksum=self.data_checksum,
96+
location=self.location,
97+
file_name=self.file_name)
8698

8799
@property
88100
@abc.abstractmethod

cebra/datasets/hippocampus.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,27 +50,35 @@
5050
# Download locations for the preprocessed rat hippocampus recordings, served
# as gzipped joblib files from the CEBRA S3-compatible bucket. For each rat,
# "checksum" is the MD5 of the extracted .jl file and "gzipped_checksum" is
# the MD5 of the .gz archive as downloaded.
_RAT_DATA_BASE_URL = "https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus"

rat_dataset_urls = {
    name: {
        "url": f"{_RAT_DATA_BASE_URL}/{name}.jl.gz",
        "checksum": checksum,
        "gzipped_checksum": gzipped_checksum,
    } for name, checksum, gzipped_checksum in [
        ("achilles", "c52f9b55cbc23c66d57f3842214058b8",
         "5d7b243e07b24c387e5412cd5ff46f0b"),
        ("buddy", "36341322907708c466871bf04bc133c2",
         "339290585be2188f48a176f05aaf5df6"),
        ("cicero", "a83b02dbdc884fdd7e53df362499d42f",
         "f262a87d2e59f164cb404cd410015f3a"),
        ("gatsby", "2b889da48178b3155011c12555342813",
         "564e431c19e55db2286a9d64c86a94c4"),
    ]
}
7684

@@ -95,11 +103,13 @@ def __init__(self, name="achilles", root=_DEFAULT_DATADIR, download=True):
95103
location = pathlib.Path(root) / "rat_hippocampus"
96104
file_path = location / f"{name}.jl"
97105

98-
super().__init__(download=download,
99-
data_url=rat_dataset_urls[name]["url"],
100-
data_checksum=rat_dataset_urls[name]["checksum"],
101-
location=location,
102-
file_name=f"{name}.jl")
106+
super().__init__(
107+
download=download,
108+
data_url=rat_dataset_urls[name]["url"],
109+
data_checksum=rat_dataset_urls[name]["checksum"],
110+
gzipped_checksum=rat_dataset_urls[name].get("gzipped_checksum"),
111+
location=location,
112+
file_name=f"{name}.jl")
103113

104114
data = joblib.load(file_path)
105115
self.neural = torch.from_numpy(data["spikes"]).float()

cebra/datasets/monkey_reaching.py

Lines changed: 52 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -160,75 +160,99 @@ def _get_info(trial_info, data):
160160
# Download locations for the preprocessed monkey reaching sessions, served as
# gzipped joblib files from the CEBRA S3-compatible bucket. Keys are the
# extracted file names; "checksum" is the MD5 of the extracted .jl file and
# "gzipped_checksum" is the MD5 of the downloaded .gz archive.
_MONKEY_DATA_BASE_URL = ("https://cebra.fra1.digitaloceanspaces.com/data/"
                         "monkey_reaching_preload_smth_40")

monkey_reaching_urls = {
    file_name: {
        "url": f"{_MONKEY_DATA_BASE_URL}/{file_name}.gz",
        "checksum": checksum,
        "gzipped_checksum": gzipped_checksum,
    } for file_name, checksum, gzipped_checksum in [
        ("all_all.jl", "dea556301fa4fafa86e28cf8621cab5a",
         "399abc6e9ef0b23a0d6d057c6f508939"),
        ("all_train.jl", "e280e4cd86969e6fd8bfd3a8f402b2fe",
         "eb52c8641fe83ae2a278b372ddec5f69"),
        ("all_test.jl", "25d3ff2c15014db8b8bf2543482ae881",
         "7688245cf15e0b92503af943ce9f66aa"),
        ("all_valid.jl", "8cd25169d31f83ae01b03f7b1b939723",
         "b169fc008b4d092fe2a1b7e006cd17a7"),
        ("active_all.jl", "c626acea5062122f5a68ef18d3e45e51",
         "b7b86e2ae00bb71341de8fc352dae097"),
        ("active_train.jl", "72a48056691078eee22c36c1992b1d37",
         "56687c633efcbff6c56bbcfa35597565"),
        ("active_test.jl", "35b7e060008a8722c536584c4748f2ea",
         "2057ef1846908a69486a61895d1198e8"),
        ("active_valid.jl", "dd58eb1e589361b4132f34b22af56b79",
         "60b8e418f234877351fe36f1efc169ad"),
        ("passive_all.jl", "bbb1bc9d8eec583a46f6673470fc98ad",
         "afb257efa0cac3ccd69ec80478d63691"),
        ("passive_train.jl", "f22e05a69f70e18ba823a0a89162a45c",
         "24d98d7d41a52591f838c41fe83dc2c6"),
        ("passive_test.jl", "42453ae3e4fd27d82d297f78c13cd6b7",
         "f1ff4e9b7c4a0d7fa9dcd271893f57ab"),
        ("passive_valid.jl", "2dcc10c27631b95a075eaa2d2297bb4a",
         "311fcb6a3e86022f12d78828f7bd29d5"),
    ]
}
234258

@@ -270,6 +294,8 @@ def __init__(self,
270294
data_url=monkey_reaching_urls[f"{self.load_session}_all.jl"]["url"],
271295
data_checksum=monkey_reaching_urls[f"{self.load_session}_all.jl"]
272296
["checksum"],
297+
gzipped_checksum=monkey_reaching_urls[f"{self.load_session}_all.jl"]
298+
.get("gzipped_checksum"),
273299
location=self.path,
274300
file_name=f"{self.load_session}_all.jl",
275301
)
@@ -297,6 +323,8 @@ def split(self, split):
297323
["url"],
298324
data_checksum=monkey_reaching_urls[
299325
f"{self.load_session}_{split}.jl"]["checksum"],
326+
gzipped_checksum=monkey_reaching_urls[
327+
f"{self.load_session}_{split}.jl"].get("gzipped_checksum"),
300328
location=self.path,
301329
file_name=f"{self.load_session}_{split}.jl",
302330
)

0 commit comments

Comments (0)