Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/transformers/utils/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,7 @@ def _upload_modified_files(
token: Optional[Union[bool, str]] = None,
create_pr: bool = False,
revision: str = None,
commit_description: str = None,
):
"""
Uploads all modified files in `working_dir` to `repo_id`, based on `files_timestamps`.
Expand Down Expand Up @@ -778,6 +779,7 @@ def _upload_modified_files(
repo_id=repo_id,
operations=operations,
commit_message=commit_message,
commit_description=commit_description,
token=token,
create_pr=create_pr,
revision=revision,
Expand All @@ -794,6 +796,7 @@ def push_to_hub(
create_pr: bool = False,
safe_serialization: bool = False,
revision: str = None,
commit_description: str = None,
**deprecated_kwargs,
) -> str:
"""
Expand Down Expand Up @@ -825,6 +828,8 @@ def push_to_hub(
Whether or not to convert the model weights in safetensors format for safer serialization.
revision (`str`, *optional*):
Branch to push the uploaded files to.
commit_description (`str`, *optional*):
The description of the commit that will be created

Examples:

Expand Down Expand Up @@ -901,6 +906,7 @@ def push_to_hub(
token=token,
create_pr=create_pr,
revision=revision,
commit_description=commit_description,
)


Expand Down
17 changes: 17 additions & 0 deletions tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,6 +1119,23 @@ def test_push_to_hub(self):
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))

def test_push_to_hub_with_description(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = BertModel(config)
COMMIT_DESCRIPTION = """
The commit description supports markdown synthax see:
```python
>>> form transformers import AutoConfig
>>> config = AutoConfig.from_pretrained("bert-base-uncased")
```
"""
commit_details = model.push_to_hub(
"test-model", use_auth_token=self._token, create_pr=True, commit_description=COMMIT_DESCRIPTION
)
self.assertEqual(commit_details.commit_description, COMMIT_DESCRIPTION)

@unittest.skip("This test is flaky")
def test_push_to_hub_in_organization(self):
config = BertConfig(
Expand Down