Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 43 additions & 9 deletions src/transformers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
from ..tokenization_utils import PreTrainedTokenizer
from ..utils import (
CONFIG_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
cached_file,
extract_commit_hash,
find_adapter_config_file,
is_kenlm_available,
is_offline_mode,
Expand Down Expand Up @@ -698,11 +701,14 @@ def pipeline(
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
token = use_auth_token

code_revision = kwargs.pop("code_revision", None)
commit_hash = kwargs.pop("_commit_hash", None)

hub_kwargs = {
"revision": revision,
"token": token,
"trust_remote_code": trust_remote_code,
"_commit_hash": None,
"_commit_hash": commit_hash,
}

if task is None and model is None:
Expand All @@ -727,28 +733,53 @@ def pipeline(
if isinstance(model, Path):
model = str(model)

if commit_hash is None:
pretrained_model_name_or_path = None
if isinstance(config, str):
pretrained_model_name_or_path = config
elif config is None and isinstance(model, str):
pretrained_model_name_or_path = model
Comment on lines +737 to +741
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try to get the repo. in different case before going to the next block of getting config, similar to the block in line 758


if not isinstance(config, PretrainedConfig) and pretrained_model_name_or_path is not None:
# We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
resolved_config_file = cached_file(
pretrained_model_name_or_path,
CONFIG_NAME,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
**hub_kwargs,
)
hub_kwargs["_commit_hash"] = extract_commit_hash(resolved_config_file, commit_hash)
else:
hub_kwargs["_commit_hash"] = getattr(config, "_commit_hash", None)

# Config is the primordial information item.
# Instantiate config if needed
if isinstance(config, str):
config = AutoConfig.from_pretrained(config, _from_pipeline=task, **hub_kwargs, **model_kwargs)
config = AutoConfig.from_pretrained(
config, _from_pipeline=task, code_revision=code_revision, **hub_kwargs, **model_kwargs
)
hub_kwargs["_commit_hash"] = config._commit_hash
elif config is None and isinstance(model, str):
# Check for an adapter file in the model path if PEFT is available
if is_peft_available():
subfolder = hub_kwargs.get("subfolder", None)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

subfolder was never in hub_kwargs here

# `find_adapter_config_file` doesn't accept `trust_remote_code`
_hub_kwargs = {k: v for k, v in hub_kwargs.items() if k != "trust_remote_code"}
maybe_adapter_path = find_adapter_config_file(
model,
revision=revision,
token=use_auth_token,
subfolder=subfolder,
token=hub_kwargs["token"],
revision=hub_kwargs["revision"],
_commit_hash=hub_kwargs["_commit_hash"],
)

if maybe_adapter_path is not None:
with open(maybe_adapter_path, "r", encoding="utf-8") as f:
adapter_config = json.load(f)
model = adapter_config["base_model_name_or_path"]

config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
config = AutoConfig.from_pretrained(
model, _from_pipeline=task, code_revision=code_revision, **hub_kwargs, **model_kwargs
)
hub_kwargs["_commit_hash"] = config._commit_hash

custom_tasks = {}
Expand All @@ -769,7 +800,7 @@ def pipeline(
"Inferring the task automatically requires to check the hub with a model_id defined as a `str`."
f"{model} is not a valid model_id."
)
task = get_task(model, use_auth_token)
task = get_task(model, token)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use_auth_token is already deprecated.


# Retrieve the task
if task in custom_tasks:
Expand All @@ -784,7 +815,10 @@ def pipeline(
)
class_ref = targeted_task["impl"]
pipeline_class = get_class_from_dynamic_module(
class_ref, model, revision=revision, use_auth_token=use_auth_token
class_ref,
model,
code_revision=code_revision,
**hub_kwargs,
)
else:
normalized_task, targeted_task, task_options = check_task(task)
Expand Down