TF port of the Segment Anything Model (SAM) #22970
Conversation
The documentation is not available anymore as the PR was closed or merged.
    def flatten(input, start_dim=0, end_dim=-1):
        # Replicates the behavior of torch.flatten in TF

        # If end_dim or start_dim is negative, count them from the end
        if end_dim < 0:
            end_dim += input.shape.rank
        if start_dim < 0:
            start_dim += input.shape.rank

        if start_dim == end_dim:
            return input

        in_shape = tf.shape(input)
        flattened_dim = tf.math.reduce_prod(in_shape[start_dim : end_dim + 1])
        out_shape = tf.concat([in_shape[:start_dim], [flattened_dim], in_shape[end_dim + 1 :]], axis=0)
        return tf.reshape(input, out_shape)
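(For illustration, a quick usage sketch - shapes invented for the example - showing it matches torch.flatten semantics:)

    import tensorflow as tf

    x = tf.zeros((2, 3, 4, 5))
    # Collapse dims 1..2 into one axis, like torch.flatten(x, 1, 2)
    y = flatten(x, start_dim=1, end_dim=2)
    print(y.shape)  # (2, 12, 5)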
🥲
I have no idea why I didn't do this before now!
        return output_masks

    def post_process_masks_tf(
Have we started including separate post-processing ops in native TensorFlow? I thought they were NumPy only. This is indeed nice.
I wasn't sure about this - there's probably some code duplication in the processor I can remove.
Preprocessing methods are all in NumPy - this hasn't been extended to the postprocessing methods yet, mainly because I haven't dared tackle torch.nn.functional.interpolate, and partly because we haven't needed to.
That said - please don't add post_process_xxx_tf methods! We don't use decode_tf for our tokenizers ;)
Could you rework the methods so there's a single post_process_xxx method that dispatches to hidden framework-specific methods? i.e.
    def post_process_masks(self, masks, ...):
        if is_torch_tensor(masks):
            return self._post_process_masks_pt(...)
        if is_tf_tensor(masks):
            return self._post_process_masks_tf(...)
        ...
Sure! And sorry - I basically rushed through the processor code so I could get to the bit I was hype about (benchmarking GPT-4's translations)
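(As an aside, a TF analogue of torch.nn.functional.interpolate for the bilinear case can be built on tf.image.resize - a minimal sketch assuming channels-first inputs, not code from this PR:)

    import tensorflow as tf

    def interpolate_bilinear(images, size):
        # tf.image.resize expects channels-last, so transpose around the resize
        images = tf.transpose(images, [0, 2, 3, 1])  # NCHW -> NHWC
        images = tf.image.resize(images, size, method="bilinear")
        return tf.transpose(images, [0, 3, 1, 2])  # NHWC -> NCHW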
This is now almost ready to go and the code should be ready for review! Remaining issues:
Why are there two different processing files, one of them not being imported anywhere?
The common tests should not be changed to have a higher tolerance; just override the right tests in the proper test file.
Also cc @amyeroberts since you reviewed the PyTorch model extensively.
Why have two functions that do the exact same thing?
Resolved as part of the general processor refactor!
Seems like this is leftover from debugging...
Resolved as part of the general processor refactor! (also oops, sorry)
What is the purpose of this file?
Shh, it's gone now. We don't talk about processing_tf_sam
Why have a separate test file to test the same class?
Also gone now!
Why change the tolerance for this model?
Adding a tolerance argument to the base tests triggered the test to run in other models, which caused this test to fail. I'll investigate and see if it's necessary, though!
Can we use more descriptive variable names?
This was copied straight from the PyTorch code, but on reflection I could probably refactor the whole thing out: it was only there to deal with different memory orderings, whereas TensorFlow tensors are always contiguous and always have standard C memory ordering.
Done! I refactored the functional_layernorm function to handle alternate axes, and then just called that instead of this manual layernorm. Model output is unchanged and all integration tests still pass.
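(For reference, a minimal sketch of what a functional layernorm over an arbitrary axis can look like - the actual functional_layernorm helper may differ in details:)

    import tensorflow as tf

    def functional_layernorm(inputs, weight, bias, epsilon=1e-6, axis=-1):
        # Normalize over `axis`, then rescale with broadcastable weight/bias
        mean, variance = tf.nn.moments(inputs, axes=[axis], keepdims=True)
        outputs = (inputs - mean) * tf.math.rsqrt(variance + epsilon)
        # Reshape weight/bias so they broadcast along the normalized axis
        shape = [1] * len(inputs.shape)
        shape[axis] = inputs.shape[axis]
        return outputs * tf.reshape(weight, shape) + tf.reshape(bias, shape)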
Replace this by?
Clarified that comment!
To address.
I never figured this out, but it's the same in Torch, and both models give equivalent outputs. @ArthurZucker do you know why this weight is non-trainable?
Couldn't find any reference to this random embedding in the paper (in fact, the paper always mentions learned positional embeddings), but the same pattern is in the SAM codebase
This meme is all I can think of
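(For context, marking a weight non-trainable in Keras looks roughly like this hypothetical layer - not the PR's actual code:)

    import tensorflow as tf

    class RandomPositionalEmbedding(tf.keras.layers.Layer):
        def __init__(self, embed_dim, **kwargs):
            super().__init__(**kwargs)
            self.embed_dim = embed_dim

        def build(self, input_shape):
            # trainable=False excludes the weight from gradient updates
            self.embedding = self.add_weight(
                name="embedding",
                shape=(self.embed_dim,),
                initializer=tf.random_normal_initializer(),
                trainable=False,
            )
            super().build(input_shape)

        def call(self, inputs):
            return inputs + self.embedding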
Thanks for the review - about half of the comments relate to the processor code, which is definitely in need of a refactor, yes. Working on that now!
Looking good!
Left some general comments - mainly w.r.t. the processing code. I'd like there to be as little TF/PT-specific code as possible. For postprocessing it's OK, as a lot of postprocessing is still PyTorch-specific, but preprocessing should be (as much as possible) framework agnostic.
For the processor, can you add pt_tf cross checks to make sure that TF postprocessed outputs are equivalent to the PT ones?
    # overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise
    # to generate masks during test
    - def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict):
    + def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict, tol=1e-5):
Do you need to add the tol argument here? Unless necessary, I'd avoid resetting the tol default in all the methods so we only need to update in one place
I refactored this and reverted all the changes in the common tests
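(i.e., keeping the common default and loosening it only in the model-specific test class, roughly like this sketch - the tolerance value here is illustrative:)

    # In the SAM test file, override the common check instead of changing its default
    def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict):
        super().check_pt_tf_models(tf_model, pt_model, tf_inputs_dict, tol=3e-4)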
    if output_hidden_states:
        vision_hidden_states = vision_outputs[1]
    if output_attentions:
        vision_attentions = vision_outputs[-1]
Could we instead pass return_dict=True to self.vision_encoder and then explicitly access the values by name? I'm not a big fan of accessing by index here.
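(i.e., something along these lines - a sketch assuming the standard transformers output attribute names:)

    vision_outputs = self.vision_encoder(
        pixel_values,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,
    )
    vision_hidden_states = vision_outputs.hidden_states
    vision_attentions = vision_outputs.attentions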
Done! (Also changed in the original PT code)
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
Why have these arguments?
I'm not clear about this one! Aren't these arguments common across most of our models?
I don't think so? Only SAM has get_image_embeddings, and all the other get_xxx_embeddings methods, as far as I can tell, just take self.
    # Matt: The original Torch code checked that the sum of sparse_prompt_embeddings equalled 0. However, this only
    # happens when the sparse prompt embeddings are an empty tensor with shape[1] == 0. I replaced
    # it with an explicit shape check to avoid data-dependent control flow which breaks XLA.
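(Schematically, the change amounts to the following, with the surrounding code elided:)

    # Before: data-dependent control flow, which breaks XLA tracing
    # if tf.reduce_sum(sparse_prompt_embeddings) != 0: ...
    # After: a static shape check, resolved at trace time
    if sparse_prompt_embeddings.shape[1] != 0:
        ...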
:)
@amyeroberts @sgugger I refactored all the changes to the common tests, and just overrode check_pt_tf_outputs.
Looking good! 💪
Additional general comment: it seems like the Keras training argument is missing all around (in call and in the dropout layers)... but on the other hand, SAM is not trainable. Still, in case we add a training script, I'd add this quick future-proof change :D
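(i.e., threading training through call and into the dropout layers, roughly like this - the layer names here are invented for illustration:)

    def call(self, hidden_states, training=False):
        hidden_states = self.dense(hidden_states)
        # Dropout is only active when training=True
        hidden_states = self.dropout(hidden_states, training=training)
        return hidden_states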
@gante I think all comments are now addressed, and I added the training argument. All comments from @amyeroberts and @sgugger should be addressed too - are you okay with going ahead and merging now once tests pass?
Thanks for all the work on this. @amyeroberts could you also have a look before this is merged?
      points (`torch.Tensor`, **optional**):
          point coordinates and labels to embed.
    - boxes (`torch.Tensor`, **optionnal**):
    + boxes (`torch.Tensor`, **optional**):
          boxes to embed
    - masks (`torch.Tensor`, **optionnal**):
    + masks (`torch.Tensor`, **optional**):
Since we are touching this, can you put the optionals in italics and not bold?
Done!
    - return_dict=return_dict,
    + return_dict=True,
This cannot be forced as return_dict breaks jit compilation. This change needs reverting.
My bad - this was my suggestion, sorry @Rocketknight1!
Done!
        values.
        """

    def __init__(self, config, downsample_rate=None, **kwargs) -> None:
The -> None makes zero sense to me as a type annotation (I know it's what PEP says, but the init returns an instance of the class). Since there are no type annotations elsewhere, maybe just remove it?
Done! (for all classes across both the PT and TF files)
Nice! 🔥
Thanks for iterating, and in particular for spending the time to add equivalence tests for the processor and keep the image processing code tidy with the two frameworks 🤗
        self.assertTrue(np.all(tf_masks[0].numpy() == pt_masks[0].numpy()))

    def test_image_processor_equivalence(self):
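(The core of such a cross-framework check is just an elementwise NumPy comparison; a minimal helper sketch, not the test's exact code:)

    import numpy as np

    def assert_pt_tf_equivalent(pt_tensor, tf_tensor, atol=1e-6):
        # Compare a PyTorch tensor and a TF tensor as NumPy arrays
        np.testing.assert_allclose(pt_tensor.detach().numpy(), tf_tensor.numpy(), atol=atol)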
🤗
This tolerance seems pretty high 👀
It's actually okay - the values for the scores are very large (usually in the range 5-30). A tolerance of 2e-4 for numbers that big is quite tight: as a relative error it's at most 2e-4 / 5 = 4e-5.
Note in #23376 - input_boxes should be a list of list of ints.
Fixed!
I know this is just copied from the PT implementation - but it would be great to add info to the docstring about what's returned, as there are many objects.
I'll be honest that I don't understand it too well, lol. I'll leave that for a follow-up on the Torch end and copy the strings whenever they do it 😅
layer norm layer here should take eps from config
layer norm layers here should take eps from config
The PyTorch version doesn't, and just uses the 1e-6 default kwarg value!
OK 👍
Let's hope it's not too experimental 😬
tnp has been around since 2.4, I think we're safe!
ha! for TF I doubt it ;)
I think comments are addressed now - are we okay to merge?
I'm treating silence as agreement, merging!
* First commit
* Add auto-translation with GPT-4
* make fixup
* Add a functional layernorm for TF
* Add all the auxiliary imports etc.
* Add the extra processor and tests
* rebase to main
* Add all the needed fixes to the GPT code
* make fixup
* Make convolutions channels-last so they run on CPU
* make fixup
* Fix final issues
* Fix other models affected by test change
* Clarify comment on the sparse_prompt_embeddings check
* Refactor functional_layernorm, use shape_list in place of .shape in some places
* Remove deprecated torch-alike code
* Update tests/models/sam/test_modeling_tf_sam.py Co-authored-by: amyeroberts <[email protected]>
* Update tests/models/sam/test_modeling_tf_sam.py Co-authored-by: amyeroberts <[email protected]>
* Refactor processor with common methods and separated private methods
* make fixup
* Quietly delete the file that didn't do anything (sorry Sylvain)
* Refactor the processor tests into one file
* make fixup
* Clean up some unnecessary indirection
* Fix TF mask postprocessing
* Add more processor equivalence tests
* Refactor generate_crop_boxes to use framework-neutral np code
* Make the serving output correctly conditional
* Fix error message line length
* Use dict keys rather than indices internally in both TF and PT SAM call/forward
* Return dicts internally in the call/forward methods
* Revert changes to common tests and just override check_pt_tf_outputs
* Revert changes to other model tests
* Clarify comments for functional layernorm
* Add missing transpose from PT code
* Removed unused copied from in PT code
* Remove overrides for tests that don't exist in TF
* Fix transpose and update tests for PT and TF to check pred_masks
* Add training flag
* Update tests to use TF checkpoints
* Update index.mdx
* Add missing cross-test decorator
* Remove optional extra asterisks
* Revert return_dict changes in PT code
* Update src/transformers/models/sam/modeling_tf_sam.py Co-authored-by: Sylvain Gugger <[email protected]>
* Remove None return annotations on init methods
* Update tests/models/sam/test_processor_sam.py Co-authored-by: amyeroberts <[email protected]>
* Fix input_boxes shapes
* make fixup

---------

Co-authored-by: amyeroberts <[email protected]>
Co-authored-by: Sylvain Gugger <[email protected]>
This is a first draft of the SAM port - will update this PR as I port tests and make sure everything is working okay. It's also a first proof-of-concept for full GPT-4 auto-translation from PyTorch: the entire modeling_tf_sam.py file was converted from PyTorch by GPT-4, with the exception of the imports at the top, because I haven't written a prompt for those yet.

Update: I checked over all of the code and fixed the issues in the GPT port. Equivalence tests all look good! This is almost ready to merge, but there are a few small issues left:

- channels_first doesn't actually work on CPU in TF
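(For context, the usual workaround, sketched here rather than taken from the PR, is to run Conv2D in channels_last and transpose at the boundaries, since TF's CPU conv kernels generally only support NHWC:)

    import tensorflow as tf

    conv = tf.keras.layers.Conv2D(filters=64, kernel_size=3, padding="same")

    def conv_channels_first(x):
        # x: (batch, channels, height, width); TF CPU convs want NHWC
        x = tf.transpose(x, [0, 2, 3, 1])  # NCHW -> NHWC
        x = conv(x)
        return tf.transpose(x, [0, 3, 1, 2])  # NHWC -> NCHW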