Skip to content

Commit e29e861

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 7cbb4d8 commit e29e861

1 file changed

Lines changed: 90 additions & 71 deletions

File tree

modules/idc_dataset.ipynb

Lines changed: 90 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,14 @@
7070
"\n",
7171
"# Restart runtime after installing ITK (required for ITK to load properly)\n",
7272
"import sys\n",
73+
"\n",
7374
"if \"google.colab\" in sys.modules:\n",
7475
" try:\n",
7576
" import itk # noqa: F401\n",
7677
" except ImportError:\n",
7778
" print(\"Restarting runtime to load ITK...\")\n",
7879
" import os\n",
80+
"\n",
7981
" os.kill(os.getpid(), 9)"
8082
]
8183
},
@@ -170,14 +172,16 @@
170172
"\n",
171173
"# Query the primary 'index' table to see the overall scale of IDC.\n",
172174
"# 'collection_id' groups series by dataset; 'PatientID' is the DICOM patient identifier.\n",
173-
"stats = client.sql_query(\"\"\"\n",
175+
"stats = client.sql_query(\n",
176+
" \"\"\"\n",
174177
" SELECT COUNT(DISTINCT collection_id) as collections,\n",
175178
" COUNT(DISTINCT PatientID) as patients,\n",
176179
" COUNT(DISTINCT SeriesInstanceUID) as series,\n",
177180
" SUM(series_size_MB) as size_mb,\n",
178181
" ROUND(SUM(series_size_MB) / 1e6, 2) as size_tb\n",
179182
" FROM index\n",
180-
"\"\"\")\n",
183+
"\"\"\"\n",
184+
")\n",
181185
"row = stats.iloc[0]\n",
182186
"print(f\"Collections: {row['collections']}, Patients: {row['patients']}, Total size: {row['size_tb']}TB\")"
183187
]
@@ -215,7 +219,8 @@
215219
"\n",
216220
"# Join collections_index with the per-series index to filter by both cancer type and modality.\n",
217221
"# The 'index' table has Modality per series; collections_index has CancerTypes per collection.\n",
218-
"lung_collections = client.sql_query(\"\"\"\n",
222+
"lung_collections = client.sql_query(\n",
223+
" \"\"\"\n",
219224
" SELECT c.collection_id, c.Subjects, c.CancerTypes,\n",
220225
" COUNT(DISTINCT CASE WHEN i.Modality = 'CT' THEN i.SeriesInstanceUID END) as ct_series\n",
221226
" FROM collections_index c\n",
@@ -225,7 +230,8 @@
225230
" HAVING ct_series > 0\n",
226231
" ORDER BY c.Subjects DESC\n",
227232
" LIMIT 5\n",
228-
"\"\"\")\n",
233+
"\"\"\"\n",
234+
")\n",
229235
"print(\"Lung CT collections:\")\n",
230236
"print(lung_collections.to_string(index=False))"
231237
]
@@ -260,14 +266,16 @@
260266
"# Select a few small CT series that form well-formed 3D volumes.\n",
261267
"# We join 'index' (series metadata) with 'volume_geometry_index' (geometry flags).\n",
262268
"# ORDER BY series_size_MB is not here, but LIMIT 3 keeps the demo download small.\n",
263-
"series_df = client.sql_query(\"\"\"\n",
269+
"series_df = client.sql_query(\n",
270+
" \"\"\"\n",
264271
" SELECT index.SeriesInstanceUID, PatientID, Modality,\n",
265272
" ROUND(series_size_MB, 2) as size_mb\n",
266273
" FROM index\n",
267274
" JOIN volume_geometry_index USING (SeriesInstanceUID)\n",
268275
" WHERE regularly_spaced_3d_volume = TRUE AND Modality = 'CT'\n",
269276
" LIMIT 3\n",
270-
"\"\"\")\n",
277+
"\"\"\"\n",
278+
")\n",
271279
"print(f\"Found {len(series_df)} CT series\")"
272280
]
273281
},
@@ -327,18 +335,14 @@
327335
"source": [
328336
"# Create a temporary directory to hold downloaded DICOM files.\n",
329337
"data_dir = tempfile.mkdtemp(prefix=\"idc_monai_\")\n",
330-
"series_uids = list(series_df['SeriesInstanceUID'])\n",
338+
"series_uids = list(series_df[\"SeriesInstanceUID\"])\n",
331339
"\n",
332340
"print(f\"Downloading {len(series_uids)} series...\")\n",
333341
"# download_from_selection() accepts a list of SeriesInstanceUIDs and fetches\n",
334342
"# all DICOM files for those series from IDC's GCS buckets.\n",
335343
"# dirTemplate=\"%SeriesInstanceUID\" puts each series in its own subdirectory —\n",
336344
"# required because MONAI's ITKReader reads a directory to reconstruct a 3D volume.\n",
337-
"client.download_from_selection(\n",
338-
" seriesInstanceUID=series_uids,\n",
339-
" downloadDir=data_dir,\n",
340-
" dirTemplate=\"%SeriesInstanceUID\"\n",
341-
")\n",
345+
"client.download_from_selection(seriesInstanceUID=series_uids, downloadDir=data_dir, dirTemplate=\"%SeriesInstanceUID\")\n",
342346
"print(\"Done!\")"
343347
]
344348
},
@@ -383,14 +387,15 @@
383387
"source": [
384388
"# Define transforms for CT preprocessing\n",
385389
"# Use ITKReader explicitly to load DICOM series from directories\n",
386-
"transforms = Compose([\n",
387-
" LoadImaged(keys=[\"image\"], reader=ITKReader()),\n",
388-
" EnsureChannelFirstd(keys=[\"image\"]),\n",
389-
" Orientationd(keys=[\"image\"], axcodes=\"RAS\"),\n",
390-
" Spacingd(keys=[\"image\"], pixdim=(1.5, 1.5, 2.0)),\n",
391-
" ScaleIntensityRanged(keys=[\"image\"], a_min=-175, a_max=250,\n",
392-
" b_min=0.0, b_max=1.0, clip=True),\n",
393-
"])\n",
390+
"transforms = Compose(\n",
391+
" [\n",
392+
" LoadImaged(keys=[\"image\"], reader=ITKReader()),\n",
393+
" EnsureChannelFirstd(keys=[\"image\"]),\n",
394+
" Orientationd(keys=[\"image\"], axcodes=\"RAS\"),\n",
395+
" Spacingd(keys=[\"image\"], pixdim=(1.5, 1.5, 2.0)),\n",
396+
" ScaleIntensityRanged(keys=[\"image\"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),\n",
397+
" ]\n",
398+
")\n",
394399
"\n",
395400
"# Create dataset\n",
396401
"data_dicts = [{\"image\": os.path.join(data_dir, uid)} for uid in series_uids]\n",
@@ -435,15 +440,14 @@
435440
}
436441
],
437442
"source": [
438-
"\n",
439-
"image = sample['image'][0]\n",
443+
"image = sample[\"image\"][0]\n",
440444
"z = image.shape[2] // 2\n",
441445
"\n",
442446
"plt.figure(figsize=(6, 6))\n",
443-
"plt.imshow(image[:, :, z].T, cmap='gray', origin='lower')\n",
444-
"plt.title(f'CT from IDC (slice {z})')\n",
445-
"plt.axis('off')\n",
446-
"plt.show()\n"
447+
"plt.imshow(image[:, :, z].T, cmap=\"gray\", origin=\"lower\")\n",
448+
"plt.title(f\"CT from IDC (slice {z})\")\n",
449+
"plt.axis(\"off\")\n",
450+
"plt.show()"
447451
]
448452
},
449453
{
@@ -504,7 +508,8 @@
504508
"# TotalSegmentator is an AI model that auto-segments 100+ anatomical structures.\n",
505509
"# The JOIN links each segmentation back to its source CT via segmented_SeriesInstanceUID.\n",
506510
"# We sort by image size (ASC) so the demo downloads the smallest available pair.\n",
507-
"paired = client.sql_query(\"\"\"\n",
511+
"paired = client.sql_query(\n",
512+
" \"\"\"\n",
508513
" SELECT src.SeriesInstanceUID as image_uid,\n",
509514
" seg.SeriesInstanceUID as seg_uid,\n",
510515
" src.collection_id, seg.total_segments,\n",
@@ -515,7 +520,8 @@
515520
" AND seg.AlgorithmName LIKE '%TotalSegmentator%'\n",
516521
" ORDER BY src.series_size_MB ASC\n",
517522
" LIMIT 3\n",
518-
"\"\"\")\n",
523+
"\"\"\"\n",
524+
")\n",
519525
"print(\"CT with TotalSegmentator segmentations:\")\n",
520526
"print(paired.to_string(index=False))"
521527
]
@@ -538,9 +544,9 @@
538544
"\n",
539545
"print(\"Downloading image and segmentation pair...\")\n",
540546
"client.download_from_selection(\n",
541-
" seriesInstanceUID=[demo_pair['image_uid'], demo_pair['seg_uid']],\n",
547+
" seriesInstanceUID=[demo_pair[\"image_uid\"], demo_pair[\"seg_uid\"]],\n",
542548
" downloadDir=seg_dir,\n",
543-
" dirTemplate=\"%SeriesInstanceUID\"\n",
549+
" dirTemplate=\"%SeriesInstanceUID\",\n",
544550
")\n",
545551
"print(\"Done!\")"
546552
]
@@ -600,7 +606,7 @@
600606
" physical axis i. ITK affine formula:\n",
601607
" world_lps = D @ diag(spacing) @ voxel + origin\n",
602608
" \"\"\"\n",
603-
" lps_to_ras = np.diag([-1., -1., 1.])\n",
609+
" lps_to_ras = np.diag([-1.0, -1.0, 1.0])\n",
604610
" affine = np.eye(4)\n",
605611
" affine[:3, :3] = lps_to_ras @ direction @ np.diag(spacing)\n",
606612
" affine[:3, 3] = lps_to_ras @ origin\n",
@@ -645,20 +651,24 @@
645651
"\n",
646652
"\n",
647653
"# Load CT with MONAI's ITKReader\n",
648-
"ct_transforms = Compose([\n",
649-
" LoadImaged(keys=[\"image\"], reader=ITKReader()),\n",
650-
" EnsureChannelFirstd(keys=[\"image\"]),\n",
651-
"])\n",
654+
"ct_transforms = Compose(\n",
655+
" [\n",
656+
" LoadImaged(keys=[\"image\"], reader=ITKReader()),\n",
657+
" EnsureChannelFirstd(keys=[\"image\"]),\n",
658+
" ]\n",
659+
")\n",
652660
"\n",
653661
"# Load SEG with our custom LoadDicomSegd\n",
654-
"seg_transforms = Compose([\n",
655-
" LoadDicomSegd(keys=[\"label\"]),\n",
656-
" EnsureChannelFirstd(keys=[\"label\"]),\n",
657-
"])\n",
662+
"seg_transforms = Compose(\n",
663+
" [\n",
664+
" LoadDicomSegd(keys=[\"label\"]),\n",
665+
" EnsureChannelFirstd(keys=[\"label\"]),\n",
666+
" ]\n",
667+
")\n",
658668
"\n",
659669
"# Load both\n",
660-
"image_path = os.path.join(seg_dir, demo_pair['image_uid'])\n",
661-
"seg_path = os.path.join(seg_dir, demo_pair['seg_uid'])\n",
670+
"image_path = os.path.join(seg_dir, demo_pair[\"image_uid\"])\n",
671+
"seg_path = os.path.join(seg_dir, demo_pair[\"seg_uid\"])\n",
662672
"\n",
663673
"ct_data = ct_transforms({\"image\": image_path})\n",
664674
"seg_data = seg_transforms({\"label\": seg_path})\n",
@@ -782,22 +792,24 @@
782792
" modifier_seq = seg.get(\"SegmentedPropertyTypeModifierCodeSequence\", {})\n",
783793
" modifier = modifier_seq.get(\"CodeMeaning\", \"\") if modifier_seq else \"\"\n",
784794
"\n",
785-
" segments.append({\n",
786-
" \"label_id\": seg.get(\"labelID\", 0),\n",
787-
" \"name\": seg.get(\"SegmentLabel\", \"Unknown\"),\n",
788-
" \"category\": category,\n",
789-
" \"type\": seg_type,\n",
790-
" \"type_code\": f\"{coding_scheme}:{type_code}\" if type_code else \"\",\n",
791-
" \"modifier\": modifier,\n",
792-
" \"color_rgb\": seg.get(\"recommendedDisplayRGBValue\", [128, 128, 128]),\n",
793-
" \"algorithm\": seg.get(\"SegmentAlgorithmName\", \"\"),\n",
794-
" })\n",
795+
" segments.append(\n",
796+
" {\n",
797+
" \"label_id\": seg.get(\"labelID\", 0),\n",
798+
" \"name\": seg.get(\"SegmentLabel\", \"Unknown\"),\n",
799+
" \"category\": category,\n",
800+
" \"type\": seg_type,\n",
801+
" \"type_code\": f\"{coding_scheme}:{type_code}\" if type_code else \"\",\n",
802+
" \"modifier\": modifier,\n",
803+
" \"color_rgb\": seg.get(\"recommendedDisplayRGBValue\", [128, 128, 128]),\n",
804+
" \"algorithm\": seg.get(\"SegmentAlgorithmName\", \"\"),\n",
805+
" }\n",
806+
" )\n",
795807
"\n",
796808
" return sorted(segments, key=lambda x: x[\"label_id\"])\n",
797809
"\n",
798810
"\n",
799811
"# Extract segment information\n",
800-
"overlay_info = seg_data.get('label_meta_dict', {}).get('overlay_info', {})\n",
812+
"overlay_info = seg_data.get(\"label_meta_dict\", {}).get(\"overlay_info\", {})\n",
801813
"segments = get_segment_info(overlay_info)\n",
802814
"\n",
803815
"print(f\"Found {len(segments)} segments in DICOM SEG:\\n\")\n",
@@ -849,9 +861,7 @@
849861
"# Find SEG z-slice at the midpoint of the labeled region\n",
850862
"z_slices_with_labels = np.where(seg_np.sum(axis=(0, 1)) > 0)[0]\n",
851863
"seg_mid_z = (\n",
852-
" int(z_slices_with_labels[len(z_slices_with_labels) // 2])\n",
853-
" if len(z_slices_with_labels) > 0\n",
854-
" else seg_np.shape[2] // 2\n",
864+
" int(z_slices_with_labels[len(z_slices_with_labels) // 2]) if len(z_slices_with_labels) > 0 else seg_np.shape[2] // 2\n",
855865
")\n",
856866
"\n",
857867
"# Map SEG midpoint z to CT z via world coordinates\n",
@@ -915,7 +925,7 @@
915925
" for seg in all_segments:\n",
916926
" label_id = seg.get(\"labelID\", 0)\n",
917927
" rgb = seg.get(\"recommendedDisplayRGBValue\", [128, 128, 128])\n",
918-
" colors[label_id] = [rgb[0]/255, rgb[1]/255, rgb[2]/255, 1.0]\n",
928+
" colors[label_id] = [rgb[0] / 255, rgb[1] / 255, rgb[2] / 255, 1.0]\n",
919929
"\n",
920930
" return ListedColormap(colors)\n",
921931
"\n",
@@ -928,26 +938,33 @@
928938
"fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
929939
"\n",
930940
"# CT image only\n",
931-
"axes[0].imshow(ct_np[:, :, z_mid].T, cmap='gray', origin='lower', vmin=-1000, vmax=500)\n",
932-
"axes[0].set_title(f'CT Image (RAS, z={z_mid})')\n",
933-
"axes[0].axis('off')\n",
941+
"axes[0].imshow(ct_np[:, :, z_mid].T, cmap=\"gray\", origin=\"lower\", vmin=-1000, vmax=500)\n",
942+
"axes[0].set_title(f\"CT Image (RAS, z={z_mid})\")\n",
943+
"axes[0].axis(\"off\")\n",
934944
"\n",
935945
"# Segmentation only (using DICOM SEG colors)\n",
936-
"axes[1].imshow(seg_np[:, :, seg_mid_z].T, cmap=seg_cmap, origin='lower',\n",
937-
" vmin=0, vmax=len(seg_cmap.colors)-1, interpolation='nearest')\n",
938-
"axes[1].set_title(f'Segmentation (RAS, z={seg_mid_z})\\n({int(seg_np.max())} labels, DICOM SEG colors)')\n",
939-
"axes[1].axis('off')\n",
946+
"axes[1].imshow(\n",
947+
" seg_np[:, :, seg_mid_z].T,\n",
948+
" cmap=seg_cmap,\n",
949+
" origin=\"lower\",\n",
950+
" vmin=0,\n",
951+
" vmax=len(seg_cmap.colors) - 1,\n",
952+
" interpolation=\"nearest\",\n",
953+
")\n",
954+
"axes[1].set_title(f\"Segmentation (RAS, z={seg_mid_z})\\n({int(seg_np.max())} labels, DICOM SEG colors)\")\n",
955+
"axes[1].axis(\"off\")\n",
940956
"\n",
941957
"# Overlay — both arrays are in RAS, z-slices matched via world coordinates\n",
942-
"axes[2].imshow(ct_np[:, :, z_mid].T, cmap='gray', origin='lower', vmin=-1000, vmax=500)\n",
958+
"axes[2].imshow(ct_np[:, :, z_mid].T, cmap=\"gray\", origin=\"lower\", vmin=-1000, vmax=500)\n",
943959
"seg_slice = seg_np[:, :, seg_mid_z]\n",
944960
"mask = np.ma.masked_where(seg_slice == 0, seg_slice)\n",
945-
"axes[2].imshow(mask.T, cmap=seg_cmap, alpha=0.6, origin='lower',\n",
946-
" vmin=0, vmax=len(seg_cmap.colors)-1, interpolation='nearest')\n",
947-
"axes[2].set_title('Overlay (RAS)\\n(DICOM SEG colors)')\n",
948-
"axes[2].axis('off')\n",
961+
"axes[2].imshow(\n",
962+
" mask.T, cmap=seg_cmap, alpha=0.6, origin=\"lower\", vmin=0, vmax=len(seg_cmap.colors) - 1, interpolation=\"nearest\"\n",
963+
")\n",
964+
"axes[2].set_title(\"Overlay (RAS)\\n(DICOM SEG colors)\")\n",
965+
"axes[2].axis(\"off\")\n",
949966
"\n",
950-
"plt.suptitle('CT + TotalSegmentator Segmentation', fontsize=14)\n",
967+
"plt.suptitle(\"CT + TotalSegmentator Segmentation\", fontsize=14)\n",
951968
"plt.tight_layout()\n",
952969
"plt.show()\n",
953970
"\n",
@@ -994,11 +1011,13 @@
9941011
"source": [
9951012
"# Always check licenses before use\n",
9961013
"uid_list = \", \".join(f\"'{uid}'\" for uid in series_uids)\n",
997-
"licenses = client.sql_query(f\"\"\"\n",
1014+
"licenses = client.sql_query(\n",
1015+
" f\"\"\"\n",
9981016
" SELECT license_short_name, COUNT(*) as count\n",
9991017
" FROM index WHERE SeriesInstanceUID IN ({uid_list})\n",
10001018
" GROUP BY license_short_name\n",
1001-
"\"\"\")\n",
1019+
"\"\"\"\n",
1020+
")\n",
10021021
"print(\"Licenses:\")\n",
10031022
"print(licenses.to_string(index=False))"
10041023
]

0 commit comments

Comments
 (0)