Conversation

@isaac-chung
Contributor

@isaac-chung isaac-chung commented Oct 8, 2023

What does this PR do?

Fixes #26672

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@isaac-chung isaac-chung changed the title add early stopping logits processor Add early stopping logits processor Oct 8, 2023
@isaac-chung isaac-chung changed the title Add early stopping logits processor Add early stopping for Bark generation Oct 8, 2023
@isaac-chung isaac-chung changed the title Add early stopping for Bark generation Add early stopping for Bark generation via logits processor Oct 8, 2023
@isaac-chung isaac-chung marked this pull request as ready for review October 9, 2023 14:08
@isaac-chung
Contributor Author

@ylacombe maybe we can continue the conversation here.

Contributor

@ylacombe ylacombe left a comment

Hi @isaac-chung, thanks for the quick PR and the good work!

I left a few comments here, let me know if you still have questions!

Other than that, for testing, I would add a test_XXXX in LogitsProcessorTest which first checks that the new logits processor behaves as expected with a hand-made example.
Ideally, we'd have another test on BarkSemanticModelTest, but I'm not sure how to proceed yet.
Do you have any ideas?
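
For concreteness, a hand-made check of that kind could look roughly like the sketch below. It assumes the processor ends up with an (eos_token_id, min_eos_p) constructor, exposed as BarkEosPrioritizerLogitsProcessor in transformers.generation.logits_process (the name the merged code settled on, if I recall correctly), and that it forces EOS by masking every other token to -inf once the EOS softmax probability exceeds min_eos_p:

    # Sketch only -- the class name and exact behavior are assumptions, not the final test.
    import torch

    from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor


    def test_early_stop_processor():
        eos_token_id = 0
        min_eos_p = 0.1
        processor = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p)

        input_ids = torch.zeros((2, 1), dtype=torch.long)
        # hand-made scores: row 0 gives EOS a high probability, row 1 a negligible one
        scores = torch.tensor(
            [
                [5.0, 0.0, 0.0, 0.0],   # softmax(EOS) ~ 0.98 > min_eos_p -> force EOS
                [-5.0, 0.0, 0.0, 0.0],  # softmax(EOS) ~ 0.002 < min_eos_p -> leave untouched
            ]
        )

        out = processor(input_ids, scores)

        # row 0: every non-EOS logit should be masked, so decoding picks EOS next
        assert torch.isinf(out[0, 1:]).all() and (out[0, 1:] < 0).all()
        # row 1: scores should pass through unchanged
        assert torch.equal(out[1], scores[1])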

@isaac-chung
Contributor Author

isaac-chung commented Oct 10, 2023

Ideally, we'd have another test on BarkSemanticModelTest, but I'm not sure how to proceed yet.
Do you have any ideas?

I'm not entirely sure. Maybe we could assert outputs from self.model.generate with the new arg somehow?

could be possibly passed to BarkModel.generate kwargs without causing issues

To confirm that we support this, maybe we should add to BarkModelIntegrationTests.test_generate_end_to_end_with_sub_models_args as well?
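
As a rough illustration of what that end-to-end call might look like (the checkpoint, the bare min_eos_p kwarg, and how it is routed to the semantic sub-model are assumptions the integration test would need to confirm):

    import torch

    from transformers import AutoProcessor, BarkModel

    processor = AutoProcessor.from_pretrained("suno/bark-small")
    model = BarkModel.from_pretrained("suno/bark-small")

    inputs = processor("Hey, it's a test.", voice_preset="v2/en_speaker_6")

    with torch.no_grad():
        # min_eos_p should reach the semantic sub-model's generation config
        # and end the semantic step early once EOS becomes probable enough
        audio = model.generate(**inputs, min_eos_p=0.2)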

@ylacombe
Contributor

Let's try to do both!

@isaac-chung
Contributor Author

I think I managed to add to BarkModelIntegrationTests without issues. But I'd like to align on how to proceed with BarkSemanticModelTest. Specifically:

  1. Only a few tests assert the outputs. As I'm unsure what to expect, I might print the outputs and assert those
  2. I've been manually trying to fill in BarkSemanticGenerationConfig so that the generate() call does not fail. Not sure if there's a more efficient way.

Contributor

@ylacombe ylacombe left a comment

Hey @isaac-chung, I've addressed point 2 in the comments below! I'm not sure I understand point 1, though. Could you expand on it a bit?
Thanks!

@isaac-chung
Contributor Author

@ylacombe thanks! Regarding point 1, take BarkModelIntegrationTests.test_generate_end_to_end_with_sub_models_args as an example: the test does not assert any outputs and simply runs .generate(). Would that be fine here?

@ylacombe
Contributor

Let's try to find a case where the semantic model has to stop. You can take inspiration from this test:

    def test_generate_semantic(self):
        input_ids = self.inputs

        # fmt: off
        # check first ids
        expected_output_ids = [7363, 321, 41, 1461, 6915, 952, 326, 41, 41, 927,]
        # fmt: on

        # greedy decoding
        with torch.no_grad():
            output_ids = self.model.semantic.generate(
                **input_ids,
                do_sample=False,
                temperature=1.0,
                semantic_generation_config=self.semantic_generation_config,
            )
        self.assertListEqual(output_ids[0, : len(expected_output_ids)].tolist(), expected_output_ids)

So basically, an example where, with the same seed, the last output tokens are different. Do you think that's possible?

@isaac-chung
Contributor Author

If we set min_eos_p to anything that's non-zero, we only get the eos_token (set to 10000 for open-end generation). Here is what passed.

    @slow
    def test_generate_semantic_early_stop(self):
        input_ids = self.inputs

        # fmt: off
        # check first ids
        expected_output_ids = [10000]
        # fmt: on

        self.semantic_generation_config.min_eos_p = 0.05

        # greedy decoding
        with torch.no_grad():
            output_ids = self.model.semantic.generate(
                **input_ids,
                do_sample=False,
                temperature=1.0,
                semantic_generation_config=self.semantic_generation_config,
            )

        self.assertListEqual(output_ids[0, : len(expected_output_ids)].tolist(), expected_output_ids)

Is that what you have in mind?

@ylacombe
Contributor

Oh, that seems weird. Have you tried another generation strategy (i.e. do_sample=True, temperature=...)? If you get the same results, it's probably on the logits processor side!
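
For reference, the per-sample gating the processor is expected to implement would look roughly like the sketch below (an assumed outline of the intended logic, not this PR's code): EOS should only be forced for rows where its softmax probability actually exceeds min_eos_p, and all other rows should pass through untouched.

    import torch


    def force_eos_if_probable(scores: torch.FloatTensor, eos_token_id: int, min_eos_p: float) -> torch.FloatTensor:
        probs = torch.nn.functional.softmax(scores.float(), dim=-1)

        # candidate scores that keep only the EOS logit, everything else masked to -inf
        early_stop_scores = torch.full_like(scores, float("-inf"))
        early_stop_scores[:, eos_token_id] = scores[:, eos_token_id]

        # per-sample condition, so one sequence stopping early does not affect the others
        do_early_stop = (probs[:, eos_token_id] > min_eos_p).unsqueeze(-1)
        return torch.where(do_early_stop, early_stop_scores, scores)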

@ylacombe
Contributor

Regarding using a stopping criterion, I don't think it's possible at the moment. Quoting #26672:

@ArthurZucker
Collaborator

It receives None because output_scores and return_dict are not properly set

@ylacombe
Contributor

Yes, of course, but don't you think users should have the liberty to set output_scores and return_dict as they want?

@ArthurZucker
Collaborator

For sure. So the goal here is to always stop early by default? (Actually, not returning the scores might be better in terms of memory?)
What I mean is that the stopping criteria are meant to be used that way 😉

@ylacombe
Contributor

For sure. So the goal here is to always stop early by default? (Actually, not returning the scores might be better in terms of memory?) What I mean is that the stopping criteria are meant to be used that way 😉

Yes, this is the goal here. Totally agree on the stopping criteria usage! Actually, I haven't found a stopping criterion that uses scores yet, maybe because of the limitation of having to use return_dict_in_generate=True, output_scores=True. #23674 is a discussion on this, and I believe it is on @gante's radar! What do you recommend in the meantime?

@isaac-chung
Contributor Author

Hey @ylacombe / @ArthurZucker, please let me know if there's anything else I can do to move this PR forward.

Member

@gante gante left a comment

A few nits. Other than that, looks good to me! Thank you for working on it 💪

Member

@gante gante left a comment

Thank you for iterating 💛

Contributor

@ylacombe ylacombe left a comment

Thanks for iterating here, @isaac-chung!
@ArthurZucker, could you make a final review?

Two last requests on my side:

  1. Are all the Bark integration tests passing? Could you make sure they are?
  2. At the risk of repeating myself, we still need a test to make sure that generated ids with min_eos_p > 0 are shorter than generated ids without it; a rough sketch follows below.
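
A rough sketch of that length comparison, assuming the test has access to self.inputs, self.model and self.semantic_generation_config as in test_generate_semantic (this is my own phrasing, not the test that was eventually merged):

    with torch.no_grad():
        # baseline: no early stopping
        self.semantic_generation_config.min_eos_p = None
        output_ids_without = self.model.semantic.generate(
            **self.inputs,
            do_sample=False,
            temperature=1.0,
            semantic_generation_config=self.semantic_generation_config,
        )

        # early stopping enabled
        self.semantic_generation_config.min_eos_p = 0.01
        output_ids_with = self.model.semantic.generate(
            **self.inputs,
            do_sample=False,
            temperature=1.0,
            semantic_generation_config=self.semantic_generation_config,
        )

    # the early-stopped sequence should be strictly shorter
    self.assertLess(output_ids_with.shape[-1], output_ids_without.shape[-1])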

@gante
Member

gante commented Oct 25, 2023

btw, regarding it being a logits processor vs. a stopping criterion: it is my impression that we want to generate an EOS token under the conditions defined here. Since we want to generate a token, it has to be a logits processor.

(the main difference between them is that a stopping criterion stops generation right away and doesn't add any new token -- for batched generation, this can make a big difference)
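
To make that contrast concrete, a stopping-criteria-based version would look roughly like the hypothetical sketch below: it needs return_dict_in_generate=True and output_scores=True just to receive scores (the limitation discussed above), it halts generation without emitting an EOS token, and its single boolean return stops the whole batch at once rather than per sample.

    # Hypothetical, for comparison only -- not the approach taken in this PR.
    import torch

    from transformers import StoppingCriteria


    class MinEosPStoppingCriteria(StoppingCriteria):
        def __init__(self, eos_token_id: int, min_eos_p: float):
            self.eos_token_id = eos_token_id
            self.min_eos_p = min_eos_p

        def __call__(self, input_ids: torch.LongTensor, scores, **kwargs) -> bool:
            # depending on the decoding loop, `scores` may be a per-step tuple
            last_scores = scores[-1] if isinstance(scores, (tuple, list)) else scores
            probs = torch.nn.functional.softmax(last_scores.float(), dim=-1)
            # a single bool: stops every sequence in the batch at the same time
            return bool((probs[:, self.eos_token_id] >= self.min_eos_p).all())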

Collaborator

@ArthurZucker ArthurZucker left a comment

LGTM, let's just keep camel case and address the comments from @ylacombe!

@isaac-chung
Contributor Author

@ylacombe I've run this command and all tests are passing ✅

RUN_SLOW=yes python -m unittest tests.models.bark.test_modeling_bark.BarkModelIntegrationTests

@ylacombe
Contributor

LGTM! Let's wait for all the checks to pass and then merge! Thanks for the great work here and all the iterations!

@isaac-chung
Contributor Author

Thank you all again for your guidance and patience 🙏 much appreciated.

@gante gante merged commit e2bffcf into huggingface:main Oct 27, 2023
@isaac-chung isaac-chung deleted the improve-bark-generation branch October 27, 2023 10:10
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
…ace#26675)

* add early stopping logits processor

* black formatted

* indent

* follow method signature

* actual logic

* check for None

* address comments on docstrings and method signature

* add unit test under `LogitsProcessorTest` wip

* unit test passing

* black formatted

* condition per sample

* add to BarkModelIntegrationTests

* wip BarkSemanticModelTest

* rename and add to kwargs handling

* not add to BarkSemanticModelTest

* correct logic and assert last outputs tokens different in test

* doc-builder style

* read from kwargs as well

* assert len of with less than that of without

* ruff

* add back seed and test case

* add original impl default suggestion

* doc-builder

* rename and use softmax

* switch back to LogitsProcessor and update docs wording

* camelCase and spelling and saving compute

* assert strictly less than

* assert less than

* expand test_generate_semantic_early_stop instead