-
Notifications
You must be signed in to change notification settings - Fork 471
Add tutorial notebook for generating embeddings from pretrained models #2959
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR adds a comprehensive tutorial notebook demonstrating how to extract fixed-length embeddings from pretrained models in TorchGeo. The tutorial covers using DOFA and ResNet-18 pretrained models to generate embeddings from EuroSAT imagery and evaluating them with k-NN classifiers.
- Adds a complete Jupyter notebook tutorial for pretrained embedding extraction
- Demonstrates usage of two different pretrained models (DOFA and ResNet-18) with specific preprocessing requirements
- Includes visualization and evaluation of embeddings using PCA and k-NN classification
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
File | Description |
---|---|
docs/tutorials/pretrained_embeddings.ipynb | New tutorial notebook demonstrating embedding extraction from pretrained models |
docs/tutorials/basic_usage.rst | Updated documentation index to include the new embeddings tutorial |
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
This LGTM. Not sure if @adamjstewart has any other nits |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I always have nits 😈
But seriously though, this is mostly good, glad we finally have this!
"source": [ | ||
"# On Colab, this ensures the latest TorchGeo is available.\n", | ||
"\n", | ||
"%pip install torchgeo" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"%pip install torchgeo" | |
"%pip install torchgeo scikit-learn tqdm" |
tqdm is pulled in by torch, but scikit-learn isn't guaranteed to be installed, so we definitely want to add these
"\n", | ||
"root = os.path.join(tempfile.gettempdir(), 'eurosat100')\n", | ||
"datamodule = EuroSAT100DataModule(\n", | ||
" root=root, batch_size=10, num_workers=2, download=True, bands=('B02', 'B03', 'B04')\n", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why RGB-only?
"# Fit a k-NN classifier on DOFA train embeddings and evaluate on validation embeddings.\n", | ||
"# This gives a quick, label-efficient baseline without fine-tuning.\n", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some of these would be better suited as markdown cells. I don't love having multiple code cells in a row without explanations.
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Now let's do the same thing with a ResNet18 model pretrained on Sentinel-2 RGB imagery (from the SSL4EO paper).\n", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Out of curiosity (and without running the notebook myself), which did better?
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"pca = PCA(n_components=2, whiten=True)\n", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
t-SNE might be cool as well, but more work.
" embeddings = model.forward_features(x)\n", | ||
" embeddings = torch.mean(\n", | ||
" embeddings, dim=(-2, -1)\n", | ||
" ) # global average pooling over the spatial dims\n", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some of these comments can be moved from in-line to their own line to reduce formatting changes
"\n", | ||
"train_dl = datamodule.train_dataloader()\n", | ||
"val_dl = datamodule.val_dataloader()\n", | ||
"test_dl = datamodule.test_dataloader()" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test set isn't being used at the moment. Would it be better to only use val or test since we aren't doing fine-tuning?
Notebook link: https://github.com/torchgeo/torchgeo/blob/embedding_tutorial/docs/tutorials/pretrained_embeddings.ipynb