diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index a3ed744f467d..047e00bc3128 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -724,6 +724,7 @@ def _upload_modified_files( token: Optional[Union[bool, str]] = None, create_pr: bool = False, revision: str = None, + commit_description: str = None, ): """ Uploads all modified files in `working_dir` to `repo_id`, based on `files_timestamps`. @@ -778,6 +779,7 @@ def _upload_modified_files( repo_id=repo_id, operations=operations, commit_message=commit_message, + commit_description=commit_description, token=token, create_pr=create_pr, revision=revision, @@ -794,6 +796,7 @@ def push_to_hub( create_pr: bool = False, safe_serialization: bool = False, revision: str = None, + commit_description: str = None, **deprecated_kwargs, ) -> str: """ @@ -825,6 +828,8 @@ def push_to_hub( Whether or not to convert the model weights in safetensors format for safer serialization. revision (`str`, *optional*): Branch to push the uploaded files to. + commit_description (`str`, *optional*): + The description of the commit that will be created Examples: @@ -901,6 +906,7 @@ def push_to_hub( token=token, create_pr=create_pr, revision=revision, + commit_description=commit_description, ) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index bccde5af5083..2a6246f8703b 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -1119,6 +1119,23 @@ def test_push_to_hub(self): for p1, p2 in zip(model.parameters(), new_model.parameters()): self.assertTrue(torch.equal(p1, p2)) + def test_push_to_hub_with_description(self): + config = BertConfig( + vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 + ) + model = BertModel(config) + COMMIT_DESCRIPTION = """ +The commit description supports markdown synthax see: +```python +>>> form transformers import AutoConfig +>>> config = AutoConfig.from_pretrained("bert-base-uncased") +``` +""" + commit_details = model.push_to_hub( + "test-model", use_auth_token=self._token, create_pr=True, commit_description=COMMIT_DESCRIPTION + ) + self.assertEqual(commit_details.commit_description, COMMIT_DESCRIPTION) + @unittest.skip("This test is flaky") def test_push_to_hub_in_organization(self): config = BertConfig(