Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

integrate SAM (segment anything) encoder with Unet #757

Open
wants to merge 26 commits into
base: master
Choose a base branch
from

Conversation

Rusteam
Copy link

@Rusteam Rusteam commented May 3, 2023

Closes #756

Added:

  • SAM to models
  • 3 SAM backbones (vit_h, vit_b and vit_l) to encoders
  • unittests and docs for SAM

Changed:

  • flake8 pre-commit repo to github (current) and version to latest

@Rusteam Rusteam changed the title integrate SAM (segment anything) model and encoders Draft: integrate SAM (segment anything) model and encoders May 3, 2023
@Rusteam Rusteam changed the title Draft: integrate SAM (segment anything) model and encoders integrate SAM (segment anything) model and encoders May 5, 2023
@Rusteam
Copy link
Author

Rusteam commented May 5, 2023

hi @qubvel is there any update on this?
I've just trained a model using this branch and it worked.

@Rusteam
Copy link
Author

Rusteam commented May 14, 2023

@Rusteam is the code merged into the main repo??i want to use this model to fine-tune my data?

It's not. Not sure if @qubvel has had a chance to look into this PR. You could use my fork in the meanwhile. And do let me know how your fine-Tuning goes because I haven't had much success so far.

@Rusteam
Copy link
Author

Rusteam commented May 15, 2023

@Rusteam how to train a model ,can u give some outlines?as author is not responding pls help me to train a model.. I have sent u an mail pls give a look

make sure you install this package from my fork pip instal git+https://github.com/Rusteam/segmentation_models.pytorch.git@sam and then initialize your model as usual create_model("SAM", "sam-vit_b", encoder_weights=None, **kwargs) and run your training. You could pass weights="sa-1b" in kwargs if you want to fine-tune from pre-trained weights.

So far I have been able to train the model, but I can't say it's learning. I'm still struggling there. Also I cannot fit more than 1 sample per batch on a 32gb gpu with a 512 input size.

@ccl-private
Copy link

@Rusteam
Copy link
Author

Rusteam commented May 16, 2023

thanks for sharing, I'll try it if my current approach does not work. I've able to get some learning with this transformers notebook

@qubvel
Copy link
Owner

qubvel commented May 17, 2023

Hi @Rusteam, thanks a lot for your contribution and sorry for the delay, I am going to review the request and will let you know

@Rusteam
Copy link
Author

Rusteam commented May 17, 2023

Hey hey hey. While this solution worked I can't say the model was able to learn on my data. We might need to use the version before my ddp adjustments or make the model handle points and boxes as inputs, or use Sam image encoder with unet or other architectures.

from typing import Optional, Union, List, Tuple

import torch
from segment_anything.modeling import MaskDecoder, TwoWayTransformer, PromptEncoder
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it a pip package? probably need to add to reqs

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just added it to reqs, or should we make it optional?

@qubvel
Copy link
Owner

qubvel commented May 17, 2023

Yes, I was actually thinking about just pre-trained encoder integration, did you test it?

@Rusteam
Copy link
Author

Rusteam commented May 18, 2023

can we use this model to train on custom data??

@qubvel It didn't work with Unet yet, but I can make it work. Which models would be essential to integrate?

@Rusteam
Copy link
Author

Rusteam commented May 18, 2023

@Rusteam @qubvel can we use this model to train on custom data??

that was my intention as well, but I was unable to make it learn without passing box/point prompts. However, when passing a prompt along with input image, it does learn. We might need to integrate multiple inputs to forward() call for it to work, or just use sam's image encoder with other arches like Unet

@qubvel
Copy link
Owner

qubvel commented Jun 14, 2023

@Rusteam thank you! I will review it in a few days

meanwhile, did you try to train the updated version? how its behave with the update?

@Rusteam
Copy link
Author

Rusteam commented Jun 14, 2023

Performance on my data has degraded. My data is small (few K sample) and it's geo-spatial. I believe It might add an improvement on normal datasets with a decent size. Alternately, we can introduce a Boolean parameter whether to use skip connections or not.

@siddpiku
Copy link

After I do !pip install git+https://github.com/Rusteam/segmentation_models.pytorch.git@sam , I tried the following command -
smp.create_model("SAM", "sam-vit_b", encoder_weights=None, weights = 'sa-1b', image_size = 256)
But it seems to throw an error
KeyError: "Wrong architecture type SAM. Available options are: ['unet', 'unetplusplus', 'manet', 'linknet', 'fpn', 'pspnet', 'deeplabv3', 'deeplabv3plus', 'pan']"
Am I doing the installation incorrect?

@Rusteam
Copy link
Author

Rusteam commented Jun 15, 2023

@Rusteam model = smp.SAM( encoder_name="sam-vit_b" encoder_weights="sa-1b", weights=None, image_size=64, decoder_multimask_output=decoder_multiclass_output, classes=n_classes, )
what we should add in the weights variable?

we've decided to remove smp.SAM for the moment and only keep sam image encoder with Unet. Refrain from using it or use at your own risk. It was supposed to be smp.SAM(..., encoder_weights=None, weights="sa-1b")

@siddpiku note this

try this if you want to use SAM's image encoder with Unet:

smp.create_model("Unet", "sam-vit_b", encoder_weights="sa-1b", encoder_depth=4, decoder_channels=[256, 128, 64, 32])`

@siddpiku
Copy link

When I run
smp.create_model("Unet", "sam-vit_b", encoder_weights="sa-1b", encoder_depth=4, decoder_channels=[256, 128, 64, 32])
It throws an error:
KeyError: "Wrong encoder name sam-vit_b, supported encoders:....

@Rusteam
Copy link
Author

Rusteam commented Jun 16, 2023

Are you sure you've have installed from my fork?

@siddpiku
Copy link

I think so,

I run:
Screenshot 2023-06-15 at 11 23 26 PM

@sushmanthreddy
Copy link

@Rusteam if u dont mind can u pls share ur training notebook??
which u have trained??

@Rusteam
Copy link
Author

Rusteam commented Jun 22, 2023

My code is in private repo with .py files, can't share.

Can you try re-installing the package? Make sure to delete existing one first with pip uninstall ...

@siddpiku
Copy link

siddpiku commented Jul 5, 2023

The following worked for me:
-git clone the sam branch,
-modify the sam.py file like below to get rid of the errors:
-change def forward(self, x: torch.Tensor) -> list[torch.Tensor]: to def forward(self, x: torch.Tensor):

  • import segmentation_models_pytorch as smp (python file in same folder as git clone branch)
  • smp.create_model("Unet", "sam-vit_b", encoder_weights="sa-1b", encoder_depth=4, decoder_channels=[256, 128, 64, 32])
  • Try training
    What did not work -
  • For me, I tried fine tuning with 2 RTX A6000 GPU with batch size of 2 on the ACDC data (https://www.creatis.insa-lyon.fr/Challenge/acdc/databases.html) but my Dice loss did not improve after 700 epochs. (Maybe some other setting works, but I did not have time to recreate it)

@Rusteam
Copy link
Author

Rusteam commented Jul 13, 2023

@qubvel hey any updates?

@Rusab
Copy link

Rusab commented Sep 6, 2023

Please add this, this library hasn't have new features for a long time

Copy link

github-actions bot commented Nov 6, 2023

This PR is stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 15 days.

@github-actions github-actions bot added the Stale label Nov 6, 2023
@csaroff
Copy link

csaroff commented Nov 17, 2023

Is this PR ready?

@github-actions github-actions bot removed the Stale label Nov 18, 2023
@Rusteam
Copy link
Author

Rusteam commented Nov 18, 2023

It's ready.

@17SIM
Copy link

17SIM commented Nov 21, 2023

The current PR seems to work with image with the size of 1024x1024 only.

@Rusteam
Copy link
Author

Rusteam commented Nov 21, 2023

Yes, as the original Sam model

Copy link

This PR is stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 15 days.

@github-actions github-actions bot added the Stale label Jan 21, 2024
@Stinosko
Copy link

Any progress on this?

@github-actions github-actions bot removed the Stale label Jan 29, 2024
@Rusab
Copy link

Rusab commented Jan 29, 2024

Why is the library dying? no new updates in a long time

Copy link

This PR is stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 15 days.

@github-actions github-actions bot added the Stale label Mar 30, 2024
@Rusteam
Copy link
Author

Rusteam commented Mar 30, 2024

@qubvel can you merge this? It did work

@github-actions github-actions bot removed the Stale label Mar 31, 2024
@isaaccorley
Copy link

@Rusteam Consider contributing this to TorchSeg

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

SAM backbone integration