I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@ -0,0 +1,110 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
# ruff: noqa: F401
from huggingface_hub.errors import (
BadRequestError,
CacheNotFound,
CorruptedCacheException,
DisabledRepoError,
EntryNotFoundError,
FileMetadataError,
GatedRepoError,
HfHubHTTPError,
HFValidationError,
LocalEntryNotFoundError,
LocalTokenNotFoundError,
NotASafetensorsRepoError,
OfflineModeIsEnabled,
RepositoryNotFoundError,
RevisionNotFoundError,
SafetensorsParsingError,
)
from . import tqdm as _tqdm # _tqdm is the module
from ._auth import get_stored_tokens, get_token
from ._cache_assets import cached_assets_path
from ._cache_manager import (
CachedFileInfo,
CachedRepoInfo,
CachedRevisionInfo,
DeleteCacheStrategy,
HFCacheInfo,
scan_cache_dir,
)
from ._chunk_utils import chunk_iterable
from ._datetime import parse_datetime
from ._experimental import experimental
from ._fixes import SoftTemporaryDirectory, WeakFileLock, yaml_dump
from ._git_credential import list_credential_helpers, set_git_credential, unset_git_credential
from ._headers import build_hf_headers, get_token_to_send
from ._hf_folder import HfFolder
from ._http import (
configure_http_backend,
fix_hf_endpoint_in_url,
get_session,
hf_raise_for_status,
http_backoff,
reset_sessions,
)
from ._pagination import paginate
from ._paths import DEFAULT_IGNORE_PATTERNS, FORBIDDEN_FOLDERS, filter_repo_objects
from ._runtime import (
dump_environment_info,
get_aiohttp_version,
get_fastai_version,
get_fastapi_version,
get_fastcore_version,
get_gradio_version,
get_graphviz_version,
get_hf_hub_version,
get_hf_transfer_version,
get_jinja_version,
get_numpy_version,
get_pillow_version,
get_pydantic_version,
get_pydot_version,
get_python_version,
get_tensorboard_version,
get_tf_version,
get_torch_version,
is_aiohttp_available,
is_colab_enterprise,
is_fastai_available,
is_fastapi_available,
is_fastcore_available,
is_google_colab,
is_gradio_available,
is_graphviz_available,
is_hf_transfer_available,
is_jinja_available,
is_notebook,
is_numpy_available,
is_package_available,
is_pillow_available,
is_pydantic_available,
is_pydot_available,
is_safetensors_available,
is_tensorboard_available,
is_tf_available,
is_torch_available,
)
from ._safetensors import SafetensorsFileMetadata, SafetensorsRepoMetadata, TensorInfo
from ._subprocess import capture_output, run_interactive_subprocess, run_subprocess
from ._telemetry import send_telemetry
from ._typing import is_jsonable, is_simple_optional_type, unwrap_simple_optional_type
from ._validators import smoothly_deprecate_use_auth_token, validate_hf_hub_args, validate_repo_id
from .tqdm import are_progress_bars_disabled, disable_progress_bars, enable_progress_bars, tqdm, tqdm_stream_file


@ -0,0 +1,214 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains an helper to get the token from machine (env variable, secret or config file)."""
import configparser
import logging
import os
import warnings
from pathlib import Path
from threading import Lock
from typing import Dict, Optional
from .. import constants
from ._runtime import is_colab_enterprise, is_google_colab
_IS_GOOGLE_COLAB_CHECKED = False
_GOOGLE_COLAB_SECRET_LOCK = Lock()
_GOOGLE_COLAB_SECRET: Optional[str] = None
logger = logging.getLogger(__name__)
def get_token() -> Optional[str]:
"""
Get token if user is logged in.
Note: in most cases, you should use [`huggingface_hub.utils.build_hf_headers`] instead. This method is only useful
if you want to retrieve the token for purposes other than sending an HTTP request.
The token is retrieved in priority from the `HF_TOKEN` environment variable. Otherwise, we read the token file located
in the Hugging Face home folder. Returns None if the user is not logged in. To log in, use [`login`] or
`huggingface-cli login`.
Returns:
`str` or `None`: The token, `None` if it doesn't exist.
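Example (illustrative; the returned value below is a made-up placeholder, `None` is returned when not logged in):
```py
>>> from huggingface_hub import get_token
>>> get_token()
'hf_xxxxxxxxxxxxx'
```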
"""
return _get_token_from_google_colab() or _get_token_from_environment() or _get_token_from_file()
def _get_token_from_google_colab() -> Optional[str]:
"""Get token from Google Colab secrets vault using `google.colab.userdata.get(...)`.
Token is read from the vault only once per session and then stored in a global variable to avoid re-requesting
access to the vault.
"""
# If it's not Google Colab or it's Colab Enterprise, fall back to environment variable or token file authentication
if not is_google_colab() or is_colab_enterprise():
return None
# `google.colab.userdata` is not thread-safe
# This can lead to a deadlock if multiple threads try to access it at the same time
# (typically when using `snapshot_download`)
# => use a lock
# See https://github.com/huggingface/huggingface_hub/issues/1952 for more details.
with _GOOGLE_COLAB_SECRET_LOCK:
global _GOOGLE_COLAB_SECRET
global _IS_GOOGLE_COLAB_CHECKED
if _IS_GOOGLE_COLAB_CHECKED: # request access only once
return _GOOGLE_COLAB_SECRET
try:
from google.colab import userdata # type: ignore
from google.colab.errors import Error as ColabError # type: ignore
except ImportError:
return None
try:
token = userdata.get("HF_TOKEN")
_GOOGLE_COLAB_SECRET = _clean_token(token)
except userdata.NotebookAccessError:
# Means the user has a secret called `HF_TOKEN` and got a popup "please grant access to HF_TOKEN" and refused it
# => warn user but ignore error => do not re-request access to user
warnings.warn(
"\nAccess to the secret `HF_TOKEN` has not been granted on this notebook."
"\nYou will not be requested again."
"\nPlease restart the session if you want to be prompted again."
)
_GOOGLE_COLAB_SECRET = None
except userdata.SecretNotFoundError:
# Means the user did not define a `HF_TOKEN` secret => warn
warnings.warn(
"\nThe secret `HF_TOKEN` does not exist in your Colab secrets."
"\nTo authenticate with the Hugging Face Hub, create a token in your settings tab "
"(https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session."
"\nYou will be able to reuse this secret in all of your notebooks."
"\nPlease note that authentication is recommended but still optional to access public models or datasets."
)
_GOOGLE_COLAB_SECRET = None
except ColabError as e:
# Something happened but we don't know what => recommend opening a GitHub issue
warnings.warn(
f"\nError while fetching `HF_TOKEN` secret value from your vault: '{str(e)}'."
"\nYou are not authenticated with the Hugging Face Hub in this notebook."
"\nIf the error persists, please let us know by opening an issue on GitHub "
"(https://github.com/huggingface/huggingface_hub/issues/new)."
)
_GOOGLE_COLAB_SECRET = None
_IS_GOOGLE_COLAB_CHECKED = True
return _GOOGLE_COLAB_SECRET
def _get_token_from_environment() -> Optional[str]:
# `HF_TOKEN` has priority (keep `HUGGING_FACE_HUB_TOKEN` for backward compatibility)
return _clean_token(os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN"))
def _get_token_from_file() -> Optional[str]:
try:
return _clean_token(Path(constants.HF_TOKEN_PATH).read_text())
except FileNotFoundError:
return None
def get_stored_tokens() -> Dict[str, str]:
"""
Returns the parsed INI file containing the access tokens.
The file is located at `HF_STORED_TOKENS_PATH`, defaulting to `~/.cache/huggingface/stored_tokens`.
If the file does not exist, an empty dictionary is returned.
Returns: `Dict[str, str]`
Key is the token name and value is the token.
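Example (illustrative; token names and values below are made up). A stored tokens file containing
```text
[my-read-token]
hf_token = hf_xxx
[my-write-token]
hf_token = hf_yyy
```
parses to `{"my-read-token": "hf_xxx", "my-write-token": "hf_yyy"}`.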
"""
tokens_path = Path(constants.HF_STORED_TOKENS_PATH)
if not tokens_path.exists():
stored_tokens = {}
config = configparser.ConfigParser()
try:
config.read(tokens_path)
stored_tokens = {token_name: config.get(token_name, "hf_token") for token_name in config.sections()}
except configparser.Error as e:
logger.error(f"Error parsing stored tokens file: {e}")
stored_tokens = {}
return stored_tokens
def _save_stored_tokens(stored_tokens: Dict[str, str]) -> None:
"""
Saves the given configuration to the stored tokens file.
Args:
stored_tokens (`Dict[str, str]`):
The stored tokens to save. Key is the token name and value is the token.
"""
stored_tokens_path = Path(constants.HF_STORED_TOKENS_PATH)
# Write the stored tokens into an INI file
config = configparser.ConfigParser()
for token_name in sorted(stored_tokens.keys()):
config.add_section(token_name)
config.set(token_name, "hf_token", stored_tokens[token_name])
stored_tokens_path.parent.mkdir(parents=True, exist_ok=True)
with stored_tokens_path.open("w") as config_file:
config.write(config_file)
def _get_token_by_name(token_name: str) -> Optional[str]:
"""
Get the token by name.
Args:
token_name (`str`):
The name of the token to get.
Returns:
`str` or `None`: The token, `None` if it doesn't exist.
"""
stored_tokens = get_stored_tokens()
if token_name not in stored_tokens:
return None
return _clean_token(stored_tokens[token_name])
def _save_token(token: str, token_name: str) -> None:
"""
Save the given token.
If the stored tokens file does not exist, it will be created.
Args:
token (`str`):
The token to save.
token_name (`str`):
The name of the token.
"""
tokens_path = Path(constants.HF_STORED_TOKENS_PATH)
stored_tokens = get_stored_tokens()
stored_tokens[token_name] = token
_save_stored_tokens(stored_tokens)
logger.info(f"The token `{token_name}` has been saved to {tokens_path}")
def _clean_token(token: Optional[str]) -> Optional[str]:
"""Clean token by removing trailing and leading spaces and newlines.
If token is an empty string, return None.
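Example (illustrative):
```py
>>> _clean_token(" hf_xxx\n")
'hf_xxx'
>>> _clean_token("") is None
True
```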
"""
if token is None:
return None
return token.replace("\r", "").replace("\n", "").strip() or None


@ -0,0 +1,135 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Union
from ..constants import HF_ASSETS_CACHE
def cached_assets_path(
library_name: str,
namespace: str = "default",
subfolder: str = "default",
*,
assets_dir: Union[str, Path, None] = None,
):
"""Return a folder path to cache arbitrary files.
`huggingface_hub` provides a canonical folder path to store assets. This is the
recommended way to integrate a cache in a downstream library, as it will benefit from
the built-in tools to scan and delete the cache properly.
The distinction is made between files cached from the Hub and assets. Files from the
Hub are cached in a git-aware manner and entirely managed by `huggingface_hub`. See
[related documentation](https://huggingface.co/docs/huggingface_hub/how-to-cache).
All other files that a downstream library caches are considered to be "assets"
(files downloaded from external sources, extracted from a .tar archive, preprocessed
for training,...).
Once the folder path is generated, it is guaranteed to exist and to be a directory.
The path is based on 3 levels of depth: the library name, a namespace and a
subfolder. Those 3 levels grant flexibility while allowing `huggingface_hub` to
expect folders when scanning/deleting parts of the assets cache. Within a library,
it is expected that all namespaces share the same subset of subfolder names, but this
is not a mandatory rule. The downstream library then has full control over the file
structure to adopt within its cache. Namespace and subfolder are optional (they
default to a `"default/"` subfolder), but the library name is mandatory as we want
every downstream library to manage its own cache.
Expected tree:
```text
assets/
└── datasets/
│ ├── SQuAD/
│ │ ├── downloaded/
│ │ ├── extracted/
│ │ └── processed/
│ ├── Helsinki-NLP--tatoeba_mt/
│ ├── downloaded/
│ ├── extracted/
│ └── processed/
└── transformers/
├── default/
│ ├── something/
├── bert-base-cased/
│ ├── default/
│ └── training/
hub/
└── models--julien-c--EsperBERTo-small/
├── blobs/
│ ├── (...)
│ ├── (...)
├── refs/
│ └── (...)
└── [ 128] snapshots/
├── 2439f60ef33a0d46d85da5001d52aeda5b00ce9f/
│ ├── (...)
└── bbc77c8132af1cc5cf678da3f1ddf2de43606d48/
└── (...)
```
Args:
library_name (`str`):
Name of the library that will manage the cache folder. Example: `"dataset"`.
namespace (`str`, *optional*, defaults to "default"):
Namespace to which the data belongs. Example: `"SQuAD"`.
subfolder (`str`, *optional*, defaults to "default"):
Subfolder in which the data will be stored. Example: `extracted`.
assets_dir (`str`, `Path`, *optional*):
Path to the folder where assets are cached. This must not be the same folder
where Hub files are cached. Defaults to `HF_HOME / "assets"` if not provided.
Can also be set with `HF_ASSETS_CACHE` environment variable.
Returns:
Path to the cache folder (`Path`).
Example:
```py
>>> from huggingface_hub import cached_assets_path
>>> cached_assets_path(library_name="datasets", namespace="SQuAD", subfolder="download")
PosixPath('/home/wauplin/.cache/huggingface/extra/datasets/SQuAD/download')
>>> cached_assets_path(library_name="datasets", namespace="SQuAD", subfolder="extracted")
PosixPath('/home/wauplin/.cache/huggingface/extra/datasets/SQuAD/extracted')
>>> cached_assets_path(library_name="datasets", namespace="Helsinki-NLP/tatoeba_mt")
PosixPath('/home/wauplin/.cache/huggingface/extra/datasets/Helsinki-NLP--tatoeba_mt/default')
>>> cached_assets_path(library_name="datasets", assets_dir="/tmp/tmp123456")
PosixPath('/tmp/tmp123456/datasets/default/default')
```
"""
# Resolve assets_dir
if assets_dir is None:
assets_dir = HF_ASSETS_CACHE
assets_dir = Path(assets_dir).expanduser().resolve()
# Avoid names that could create path issues
for part in (" ", "/", "\\"):
library_name = library_name.replace(part, "--")
namespace = namespace.replace(part, "--")
subfolder = subfolder.replace(part, "--")
# Path to subfolder is created
path = assets_dir / library_name / namespace / subfolder
try:
path.mkdir(exist_ok=True, parents=True)
except (FileExistsError, NotADirectoryError):
raise ValueError(f"Corrupted assets folder: cannot create directory because of an existing file ({path}).")
# Return
return path


@ -0,0 +1,896 @@
# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities to manage the HF cache directory."""
import os
import shutil
import time
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, FrozenSet, List, Literal, Optional, Set, Union
from huggingface_hub.errors import CacheNotFound, CorruptedCacheException
from ..commands._cli_utils import tabulate
from ..constants import HF_HUB_CACHE
from . import logging
logger = logging.get_logger(__name__)
REPO_TYPE_T = Literal["model", "dataset", "space"]
# List of OS-created helper files that need to be ignored
FILES_TO_IGNORE = [".DS_Store"]
@dataclass(frozen=True)
class CachedFileInfo:
"""Frozen data structure holding information about a single cached file.
Args:
file_name (`str`):
Name of the file. Example: `config.json`.
file_path (`Path`):
Path of the file in the `snapshots` directory. The file path is a symlink
referring to a blob in the `blobs` folder.
blob_path (`Path`):
Path of the blob file. This is equivalent to `file_path.resolve()`.
size_on_disk (`int`):
Size of the blob file in bytes.
blob_last_accessed (`float`):
Timestamp of the last time the blob file has been accessed (from any
revision).
blob_last_modified (`float`):
Timestamp of the last time the blob file has been modified/created.
<Tip warning={true}>
`blob_last_accessed` and `blob_last_modified` reliability can depend on the OS you
are using. See [python documentation](https://docs.python.org/3/library/os.html#os.stat_result)
for more details.
</Tip>
"""
file_name: str
file_path: Path
blob_path: Path
size_on_disk: int
blob_last_accessed: float
blob_last_modified: float
@property
def blob_last_accessed_str(self) -> str:
"""
(property) Timestamp of the last time the blob file has been accessed (from any
revision), returned as a human-readable string.
Example: "2 weeks ago".
"""
return _format_timesince(self.blob_last_accessed)
@property
def blob_last_modified_str(self) -> str:
"""
(property) Timestamp of the last time the blob file has been modified, returned
as a human-readable string.
Example: "2 weeks ago".
"""
return _format_timesince(self.blob_last_modified)
@property
def size_on_disk_str(self) -> str:
"""
(property) Size of the blob file as a human-readable string.
Example: "42.2K".
"""
return _format_size(self.size_on_disk)
@dataclass(frozen=True)
class CachedRevisionInfo:
"""Frozen data structure holding information about a revision.
A revision corresponds to a folder in the `snapshots` folder and is populated with
the exact tree structure as the repo on the Hub but contains only symlinks. A
revision can either be referenced by 1 or more `refs` or be "detached" (no refs).
Args:
commit_hash (`str`):
Hash of the revision (unique).
Example: `"9338f7b671827df886678df2bdd7cc7b4f36dffd"`.
snapshot_path (`Path`):
Path to the revision directory in the `snapshots` folder. It contains the
exact tree structure as the repo on the Hub.
files (`FrozenSet[CachedFileInfo]`):
Set of [`~CachedFileInfo`] describing all files contained in the snapshot.
refs (`FrozenSet[str]`):
Set of `refs` pointing to this revision. If the revision has no `refs`, it
is considered detached.
Example: `{"main", "2.4.0"}` or `{"refs/pr/1"}`.
size_on_disk (`int`):
Sum of the blob file sizes that are symlink-ed by the revision.
last_modified (`float`):
Timestamp of the last time the revision has been created/modified.
<Tip warning={true}>
`last_accessed` cannot be determined correctly on a single revision as blob files
are shared across revisions.
</Tip>
<Tip warning={true}>
`size_on_disk` is not necessarily the sum of all file sizes because of possible
duplicated files. Besides, only blobs are taken into account, not the (negligible)
size of folders and symlinks.
</Tip>
"""
commit_hash: str
snapshot_path: Path
size_on_disk: int
files: FrozenSet[CachedFileInfo]
refs: FrozenSet[str]
last_modified: float
@property
def last_modified_str(self) -> str:
"""
(property) Timestamp of the last time the revision has been modified, returned
as a human-readable string.
Example: "2 weeks ago".
"""
return _format_timesince(self.last_modified)
@property
def size_on_disk_str(self) -> str:
"""
(property) Sum of the blob file sizes as a human-readable string.
Example: "42.2K".
"""
return _format_size(self.size_on_disk)
@property
def nb_files(self) -> int:
"""
(property) Total number of files in the revision.
"""
return len(self.files)
@dataclass(frozen=True)
class CachedRepoInfo:
"""Frozen data structure holding information about a cached repository.
Args:
repo_id (`str`):
Repo id of the repo on the Hub. Example: `"google/fleurs"`.
repo_type (`Literal["dataset", "model", "space"]`):
Type of the cached repo.
repo_path (`Path`):
Local path to the cached repo.
size_on_disk (`int`):
Sum of the blob file sizes in the cached repo.
nb_files (`int`):
Total number of blob files in the cached repo.
revisions (`FrozenSet[CachedRevisionInfo]`):
Set of [`~CachedRevisionInfo`] describing all revisions cached in the repo.
last_accessed (`float`):
Timestamp of the last time a blob file of the repo has been accessed.
last_modified (`float`):
Timestamp of the last time a blob file of the repo has been modified/created.
<Tip warning={true}>
`size_on_disk` is not necessarily the sum of all revisions sizes because of
duplicated files. Besides, only blobs are taken into account, not the (negligible)
size of folders and symlinks.
</Tip>
<Tip warning={true}>
`last_accessed` and `last_modified` reliability can depend on the OS you are using.
See [python documentation](https://docs.python.org/3/library/os.html#os.stat_result)
for more details.
</Tip>
"""
repo_id: str
repo_type: REPO_TYPE_T
repo_path: Path
size_on_disk: int
nb_files: int
revisions: FrozenSet[CachedRevisionInfo]
last_accessed: float
last_modified: float
@property
def last_accessed_str(self) -> str:
"""
(property) Last time a blob file of the repo has been accessed, returned as a
human-readable string.
Example: "2 weeks ago".
"""
return _format_timesince(self.last_accessed)
@property
def last_modified_str(self) -> str:
"""
(property) Last time a blob file of the repo has been modified, returned as a
human-readable string.
Example: "2 weeks ago".
"""
return _format_timesince(self.last_modified)
@property
def size_on_disk_str(self) -> str:
"""
(property) Sum of the blob file sizes as a human-readable string.
Example: "42.2K".
"""
return _format_size(self.size_on_disk)
@property
def refs(self) -> Dict[str, CachedRevisionInfo]:
"""
(property) Mapping between `refs` and revision data structures.
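Example (illustrative; `repo_info` is a hypothetical [`CachedRepoInfo`] instance):
```py
>>> repo_info.refs
{'main': CachedRevisionInfo(...), 'refs/pr/1': CachedRevisionInfo(...)}
```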
"""
return {ref: revision for revision in self.revisions for ref in revision.refs}
@dataclass(frozen=True)
class DeleteCacheStrategy:
"""Frozen data structure holding the strategy to delete cached revisions.
This object is not meant to be instantiated programmatically but to be returned by
[`~utils.HFCacheInfo.delete_revisions`]. See documentation for usage example.
Args:
expected_freed_size (`float`):
Expected freed size once strategy is executed.
blobs (`FrozenSet[Path]`):
Set of blob file paths to be deleted.
refs (`FrozenSet[Path]`):
Set of reference file paths to be deleted.
repos (`FrozenSet[Path]`):
Set of entire repo paths to be deleted.
snapshots (`FrozenSet[Path]`):
Set of snapshots to be deleted (directory of symlinks).
"""
expected_freed_size: int
blobs: FrozenSet[Path]
refs: FrozenSet[Path]
repos: FrozenSet[Path]
snapshots: FrozenSet[Path]
@property
def expected_freed_size_str(self) -> str:
"""
(property) Expected size that will be freed as a human-readable string.
Example: "42.2K".
"""
return _format_size(self.expected_freed_size)
def execute(self) -> None:
"""Execute the defined strategy.
<Tip warning={true}>
If this method is interrupted, the cache might get corrupted. Deletion order is
implemented so that references and symlinks are deleted before the actual blob
files.
</Tip>
<Tip warning={true}>
This method is irreversible. If executed, cached files are erased and must be
downloaded again.
</Tip>
"""
# Deletion order matters. Blobs are deleted last so that the user can't end
# up in a state where a `ref` refers to a missing snapshot or a snapshot
# symlink refers to a deleted blob.
# Delete entire repos
for path in self.repos:
_try_delete_path(path, path_type="repo")
# Delete snapshot directories
for path in self.snapshots:
_try_delete_path(path, path_type="snapshot")
# Delete refs files
for path in self.refs:
_try_delete_path(path, path_type="ref")
# Delete blob files
for path in self.blobs:
_try_delete_path(path, path_type="blob")
logger.info(f"Cache deletion done. Saved {self.expected_freed_size_str}.")
@dataclass(frozen=True)
class HFCacheInfo:
"""Frozen data structure holding information about the entire cache-system.
This data structure is returned by [`scan_cache_dir`] and is immutable.
Args:
size_on_disk (`int`):
Sum of all valid repo sizes in the cache-system.
repos (`FrozenSet[CachedRepoInfo]`):
Set of [`~CachedRepoInfo`] describing all valid cached repos found on the
cache-system while scanning.
warnings (`List[CorruptedCacheException]`):
List of [`~CorruptedCacheException`] that occurred while scanning the cache.
Those exceptions are captured so that the scan can continue. Corrupted repos
are skipped from the scan.
<Tip warning={true}>
Here `size_on_disk` is equal to the sum of all repo sizes (only blobs). However if
some cached repos are corrupted, their sizes are not taken into account.
</Tip>
"""
size_on_disk: int
repos: FrozenSet[CachedRepoInfo]
warnings: List[CorruptedCacheException]
@property
def size_on_disk_str(self) -> str:
"""
(property) Sum of all valid repo sizes in the cache-system as a human-readable
string.
Example: "42.2K".
"""
return _format_size(self.size_on_disk)
def delete_revisions(self, *revisions: str) -> DeleteCacheStrategy:
"""Prepare the strategy to delete one or more revisions cached locally.
Input revisions can be any revision hash. If a revision hash is not found in the
local cache, a warning is thrown but no error is raised. Revisions can be from
different cached repos since hashes are unique across repos.
Examples:
```py
>>> from huggingface_hub import scan_cache_dir
>>> cache_info = scan_cache_dir()
>>> delete_strategy = cache_info.delete_revisions(
... "81fd1d6e7847c99f5862c9fb81387956d99ec7aa"
... )
>>> print(f"Will free {delete_strategy.expected_freed_size_str}.")
Will free 7.9K.
>>> delete_strategy.execute()
Cache deletion done. Saved 7.9K.
```
```py
>>> from huggingface_hub import scan_cache_dir
>>> scan_cache_dir().delete_revisions(
... "81fd1d6e7847c99f5862c9fb81387956d99ec7aa",
... "e2983b237dccf3ab4937c97fa717319a9ca1a96d",
... "6c0e6080953db56375760c0471a8c5f2929baf11",
... ).execute()
Cache deletion done. Saved 8.6G.
```
<Tip warning={true}>
`delete_revisions` returns a [`~utils.DeleteCacheStrategy`] object that needs to
be executed. The [`~utils.DeleteCacheStrategy`] is not meant to be modified but
allows having a dry run before actually executing the deletion.
</Tip>
"""
hashes_to_delete: Set[str] = set(revisions)
repos_with_revisions: Dict[CachedRepoInfo, Set[CachedRevisionInfo]] = defaultdict(set)
for repo in self.repos:
for revision in repo.revisions:
if revision.commit_hash in hashes_to_delete:
repos_with_revisions[repo].add(revision)
hashes_to_delete.remove(revision.commit_hash)
if len(hashes_to_delete) > 0:
logger.warning(f"Revision(s) not found - cannot delete them: {', '.join(hashes_to_delete)}")
delete_strategy_blobs: Set[Path] = set()
delete_strategy_refs: Set[Path] = set()
delete_strategy_repos: Set[Path] = set()
delete_strategy_snapshots: Set[Path] = set()
delete_strategy_expected_freed_size = 0
for affected_repo, revisions_to_delete in repos_with_revisions.items():
other_revisions = affected_repo.revisions - revisions_to_delete
# If no other revisions, it means all revisions are deleted
# -> delete the entire cached repo
if len(other_revisions) == 0:
delete_strategy_repos.add(affected_repo.repo_path)
delete_strategy_expected_freed_size += affected_repo.size_on_disk
continue
# Some revisions of the repo will be deleted but not all. We need to filter
# which blob files will not be linked anymore.
for revision_to_delete in revisions_to_delete:
# Snapshot dir
delete_strategy_snapshots.add(revision_to_delete.snapshot_path)
# Refs dir
for ref in revision_to_delete.refs:
delete_strategy_refs.add(affected_repo.repo_path / "refs" / ref)
# Blobs dir
for file in revision_to_delete.files:
if file.blob_path not in delete_strategy_blobs:
is_file_alone = True
for revision in other_revisions:
for rev_file in revision.files:
if file.blob_path == rev_file.blob_path:
is_file_alone = False
break
if not is_file_alone:
break
# Blob file not referenced by remaining revisions -> delete
if is_file_alone:
delete_strategy_blobs.add(file.blob_path)
delete_strategy_expected_freed_size += file.size_on_disk
# Return the strategy instead of executing it.
return DeleteCacheStrategy(
blobs=frozenset(delete_strategy_blobs),
refs=frozenset(delete_strategy_refs),
repos=frozenset(delete_strategy_repos),
snapshots=frozenset(delete_strategy_snapshots),
expected_freed_size=delete_strategy_expected_freed_size,
)
def export_as_table(self, *, verbosity: int = 0) -> str:
"""Generate a table from the [`HFCacheInfo`] object.
Pass `verbosity=0` to get a table with a single row per repo, with columns
"repo_id", "repo_type", "size_on_disk", "nb_files", "last_accessed", "last_modified", "refs", "local_path".
Pass `verbosity=1` to get a table with a row per repo and revision (thus multiple rows can appear for a single repo), with columns
"repo_id", "repo_type", "revision", "size_on_disk", "nb_files", "last_modified", "refs", "local_path".
Example:
```py
>>> from huggingface_hub.utils import scan_cache_dir
>>> hf_cache_info = scan_cache_dir()
HFCacheInfo(...)
>>> print(hf_cache_info.export_as_table())
REPO ID REPO TYPE SIZE ON DISK NB FILES LAST_ACCESSED LAST_MODIFIED REFS LOCAL PATH
--------------------------------------------------- --------- ------------ -------- ------------- ------------- ---- --------------------------------------------------------------------------------------------------
roberta-base model 2.7M 5 1 day ago 1 week ago main ~/.cache/huggingface/hub/models--roberta-base
suno/bark model 8.8K 1 1 week ago 1 week ago main ~/.cache/huggingface/hub/models--suno--bark
t5-base model 893.8M 4 4 days ago 7 months ago main ~/.cache/huggingface/hub/models--t5-base
t5-large model 3.0G 4 5 weeks ago 5 months ago main ~/.cache/huggingface/hub/models--t5-large
>>> print(hf_cache_info.export_as_table(verbosity=1))
REPO ID REPO TYPE REVISION SIZE ON DISK NB FILES LAST_MODIFIED REFS LOCAL PATH
--------------------------------------------------- --------- ---------------------------------------- ------------ -------- ------------- ---- -----------------------------------------------------------------------------------------------------------------------------------------------------
roberta-base model e2da8e2f811d1448a5b465c236feacd80ffbac7b 2.7M 5 1 week ago main ~/.cache/huggingface/hub/models--roberta-base/snapshots/e2da8e2f811d1448a5b465c236feacd80ffbac7b
suno/bark model 70a8a7d34168586dc5d028fa9666aceade177992 8.8K 1 1 week ago main ~/.cache/huggingface/hub/models--suno--bark/snapshots/70a8a7d34168586dc5d028fa9666aceade177992
t5-base model a9723ea7f1b39c1eae772870f3b547bf6ef7e6c1 893.8M 4 7 months ago main ~/.cache/huggingface/hub/models--t5-base/snapshots/a9723ea7f1b39c1eae772870f3b547bf6ef7e6c1
t5-large model 150ebc2c4b72291e770f58e6057481c8d2ed331a 3.0G 4 5 months ago main ~/.cache/huggingface/hub/models--t5-large/snapshots/150ebc2c4b72291e770f58e6057481c8d2ed331a
```
Args:
verbosity (`int`, *optional*):
The verbosity level. Defaults to 0.
Returns:
`str`: The table as a string.
"""
if verbosity == 0:
return tabulate(
rows=[
[
repo.repo_id,
repo.repo_type,
"{:>12}".format(repo.size_on_disk_str),
repo.nb_files,
repo.last_accessed_str,
repo.last_modified_str,
", ".join(sorted(repo.refs)),
str(repo.repo_path),
]
for repo in sorted(self.repos, key=lambda repo: repo.repo_path)
],
headers=[
"REPO ID",
"REPO TYPE",
"SIZE ON DISK",
"NB FILES",
"LAST_ACCESSED",
"LAST_MODIFIED",
"REFS",
"LOCAL PATH",
],
)
else:
return tabulate(
rows=[
[
repo.repo_id,
repo.repo_type,
revision.commit_hash,
"{:>12}".format(revision.size_on_disk_str),
revision.nb_files,
revision.last_modified_str,
", ".join(sorted(revision.refs)),
str(revision.snapshot_path),
]
for repo in sorted(self.repos, key=lambda repo: repo.repo_path)
for revision in sorted(repo.revisions, key=lambda revision: revision.commit_hash)
],
headers=[
"REPO ID",
"REPO TYPE",
"REVISION",
"SIZE ON DISK",
"NB FILES",
"LAST_MODIFIED",
"REFS",
"LOCAL PATH",
],
)
def scan_cache_dir(cache_dir: Optional[Union[str, Path]] = None) -> HFCacheInfo:
"""Scan the entire HF cache-system and return a [`~HFCacheInfo`] structure.
Use `scan_cache_dir` in order to programmatically scan your cache-system. The cache
will be scanned repo by repo. If a repo is corrupted, a [`~CorruptedCacheException`]
will be thrown internally but captured and returned in the [`~HFCacheInfo`]
structure. Only valid repos get a proper report.
```py
>>> from huggingface_hub import scan_cache_dir
>>> hf_cache_info = scan_cache_dir()
HFCacheInfo(
size_on_disk=3398085269,
repos=frozenset({
CachedRepoInfo(
repo_id='t5-small',
repo_type='model',
repo_path=PosixPath(...),
size_on_disk=970726914,
nb_files=11,
revisions=frozenset({
CachedRevisionInfo(
commit_hash='d78aea13fa7ecd06c29e3e46195d6341255065d5',
size_on_disk=970726339,
snapshot_path=PosixPath(...),
files=frozenset({
CachedFileInfo(
file_name='config.json',
size_on_disk=1197,
file_path=PosixPath(...),
blob_path=PosixPath(...),
),
CachedFileInfo(...),
...
}),
),
CachedRevisionInfo(...),
...
}),
),
CachedRepoInfo(...),
...
}),
warnings=[
CorruptedCacheException("Snapshots dir doesn't exist in cached repo: ..."),
CorruptedCacheException(...),
...
],
)
```
You can also print a detailed report directly from the `huggingface-cli` using:
```text
> huggingface-cli scan-cache
REPO ID REPO TYPE SIZE ON DISK NB FILES REFS LOCAL PATH
--------------------------- --------- ------------ -------- ------------------- -------------------------------------------------------------------------
glue dataset 116.3K 15 1.17.0, main, 2.4.0 /Users/lucain/.cache/huggingface/hub/datasets--glue
google/fleurs dataset 64.9M 6 main, refs/pr/1 /Users/lucain/.cache/huggingface/hub/datasets--google--fleurs
Jean-Baptiste/camembert-ner model 441.0M 7 main /Users/lucain/.cache/huggingface/hub/models--Jean-Baptiste--camembert-ner
bert-base-cased model 1.9G 13 main /Users/lucain/.cache/huggingface/hub/models--bert-base-cased
t5-base model 10.1K 3 main /Users/lucain/.cache/huggingface/hub/models--t5-base
t5-small model 970.7M 11 refs/pr/1, main /Users/lucain/.cache/huggingface/hub/models--t5-small
Done in 0.0s. Scanned 6 repo(s) for a total of 3.4G.
Got 1 warning(s) while scanning. Use -vvv to print details.
```
Args:
cache_dir (`str` or `Path`, *optional*):
Cache directory to scan. Defaults to the default HF cache directory.
<Tip warning={true}>
Raises:
`CacheNotFound`
If the cache directory does not exist.
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
If the cache directory is a file, instead of a directory.
</Tip>
Returns: a [`~HFCacheInfo`] object.
"""
if cache_dir is None:
cache_dir = HF_HUB_CACHE
cache_dir = Path(cache_dir).expanduser().resolve()
if not cache_dir.exists():
raise CacheNotFound(
f"Cache directory not found: {cache_dir}. Please use `cache_dir` argument or set `HF_HUB_CACHE` environment variable.",
cache_dir=cache_dir,
)
if cache_dir.is_file():
raise ValueError(
f"Scan cache expects a directory but found a file: {cache_dir}. Please use `cache_dir` argument or set `HF_HUB_CACHE` environment variable."
)
repos: Set[CachedRepoInfo] = set()
warnings: List[CorruptedCacheException] = []
for repo_path in cache_dir.iterdir():
if repo_path.name == ".locks": # skip './.locks/' folder
continue
try:
repos.add(_scan_cached_repo(repo_path))
except CorruptedCacheException as e:
warnings.append(e)
return HFCacheInfo(
repos=frozenset(repos),
size_on_disk=sum(repo.size_on_disk for repo in repos),
warnings=warnings,
)
def _scan_cached_repo(repo_path: Path) -> CachedRepoInfo:
"""Scan a single cache repo and return information about it.
Any unexpected behavior will raise a [`~CorruptedCacheException`].
"""
if not repo_path.is_dir():
raise CorruptedCacheException(f"Repo path is not a directory: {repo_path}")
if "--" not in repo_path.name:
raise CorruptedCacheException(f"Repo path is not a valid HuggingFace cache directory: {repo_path}")
repo_type, repo_id = repo_path.name.split("--", maxsplit=1)
repo_type = repo_type[:-1] # "models" -> "model"
repo_id = repo_id.replace("--", "/")  # "google--fleurs" -> "google/fleurs"
if repo_type not in {"dataset", "model", "space"}:
raise CorruptedCacheException(
f"Repo type must be `dataset`, `model` or `space`, found `{repo_type}` ({repo_path})."
)
blob_stats: Dict[Path, os.stat_result] = {} # Key is blob_path, value is blob stats
snapshots_path = repo_path / "snapshots"
refs_path = repo_path / "refs"
if not snapshots_path.exists() or not snapshots_path.is_dir():
raise CorruptedCacheException(f"Snapshots dir doesn't exist in cached repo: {snapshots_path}")
# Scan over `refs` directory
# key is revision hash, value is set of refs
refs_by_hash: Dict[str, Set[str]] = defaultdict(set)
if refs_path.exists():
# Example of `refs` directory
# ── refs
# ├── main
# └── refs
# └── pr
# └── 1
if refs_path.is_file():
raise CorruptedCacheException(f"Refs directory cannot be a file: {refs_path}")
for ref_path in refs_path.glob("**/*"):
# glob("**/*") iterates over all files and directories -> skip directories
if ref_path.is_dir():
continue
ref_name = str(ref_path.relative_to(refs_path))
with ref_path.open() as f:
commit_hash = f.read()
refs_by_hash[commit_hash].add(ref_name)
# Scan snapshots directory
cached_revisions: Set[CachedRevisionInfo] = set()
for revision_path in snapshots_path.iterdir():
# Ignore OS-created helper files
if revision_path.name in FILES_TO_IGNORE:
continue
if revision_path.is_file():
raise CorruptedCacheException(f"Snapshots folder corrupted. Found a file: {revision_path}")
cached_files = set()
for file_path in revision_path.glob("**/*"):
# glob("**/*") iterates over all files and directories -> skip directories
if file_path.is_dir():
continue
blob_path = Path(file_path).resolve()
if not blob_path.exists():
raise CorruptedCacheException(f"Blob missing (broken symlink): {blob_path}")
if blob_path not in blob_stats:
blob_stats[blob_path] = blob_path.stat()
cached_files.add(
CachedFileInfo(
file_name=file_path.name,
file_path=file_path,
size_on_disk=blob_stats[blob_path].st_size,
blob_path=blob_path,
blob_last_accessed=blob_stats[blob_path].st_atime,
blob_last_modified=blob_stats[blob_path].st_mtime,
)
)
# Last modified is either the last modified blob file or the revision folder
# itself if it is empty
if len(cached_files) > 0:
revision_last_modified = max(blob_stats[file.blob_path].st_mtime for file in cached_files)
else:
revision_last_modified = revision_path.stat().st_mtime
cached_revisions.add(
CachedRevisionInfo(
commit_hash=revision_path.name,
files=frozenset(cached_files),
refs=frozenset(refs_by_hash.pop(revision_path.name, set())),
size_on_disk=sum(
blob_stats[blob_path].st_size for blob_path in set(file.blob_path for file in cached_files)
),
snapshot_path=revision_path,
last_modified=revision_last_modified,
)
)
# Check that all refs referred to an existing revision
if len(refs_by_hash) > 0:
raise CorruptedCacheException(
f"Reference(s) refer to missing commit hashes: {dict(refs_by_hash)} ({repo_path})."
)
# Last modified is either the last modified blob file or the repo folder itself if
# no blob files have been found. Same for last accessed.
if len(blob_stats) > 0:
repo_last_accessed = max(stat.st_atime for stat in blob_stats.values())
repo_last_modified = max(stat.st_mtime for stat in blob_stats.values())
else:
repo_stats = repo_path.stat()
repo_last_accessed = repo_stats.st_atime
repo_last_modified = repo_stats.st_mtime
# Build and return frozen structure
return CachedRepoInfo(
nb_files=len(blob_stats),
repo_id=repo_id,
repo_path=repo_path,
repo_type=repo_type, # type: ignore
revisions=frozenset(cached_revisions),
size_on_disk=sum(stat.st_size for stat in blob_stats.values()),
last_accessed=repo_last_accessed,
last_modified=repo_last_modified,
)
def _format_size(num: int) -> str:
"""Format size in bytes into a human-readable string.
Taken from https://stackoverflow.com/a/1094933
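Example (illustrative):
```py
>>> _format_size(1500)
'1.5K'
>>> _format_size(1_234_567_890)
'1.2G'
```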
"""
num_f = float(num)
for unit in ["", "K", "M", "G", "T", "P", "E", "Z"]:
if abs(num_f) < 1000.0:
return f"{num_f:3.1f}{unit}"
num_f /= 1000.0
return f"{num_f:.1f}Y"
_TIMESINCE_CHUNKS = (
# Label, divider, max value
("second", 1, 60),
("minute", 60, 60),
("hour", 60 * 60, 24),
("day", 60 * 60 * 24, 6),
("week", 60 * 60 * 24 * 7, 6),
("month", 60 * 60 * 24 * 30, 11),
("year", 60 * 60 * 24 * 365, None),
)
def _format_timesince(ts: float) -> str:
"""Format timestamp in seconds into a human-readable string, relative to now.
Vaguely inspired by Django's `timesince` formatter.
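Example (illustrative):
```py
>>> import time
>>> _format_timesince(time.time() - 2 * 24 * 3600)
'2 days ago'
```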
"""
delta = time.time() - ts
if delta < 20:
return "a few seconds ago"
for label, divider, max_value in _TIMESINCE_CHUNKS: # noqa: B007
value = round(delta / divider)
if max_value is not None and value <= max_value:
break
return f"{value} {label}{'s' if value > 1 else ''} ago"
def _try_delete_path(path: Path, path_type: str) -> None:
"""Try to delete a local file or folder.
If the path does not exist, the error is logged as a warning and then ignored.
Args:
path (`Path`):
Path to delete. Can be a file or a folder.
path_type (`str`):
What path are we deleting? Only for logging purposes. Example: "snapshot".
"""
logger.info(f"Delete {path_type}: {path}")
try:
if path.is_file():
os.remove(path)
else:
shutil.rmtree(path)
except FileNotFoundError:
logger.warning(f"Couldn't delete {path_type}: file not found ({path})", exc_info=True)
except PermissionError:
logger.warning(f"Couldn't delete {path_type}: permission denied ({path})", exc_info=True)


@ -0,0 +1,65 @@
# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains a utility to iterate by chunks over an iterator."""
import itertools
from typing import Iterable, TypeVar
T = TypeVar("T")
def chunk_iterable(iterable: Iterable[T], chunk_size: int) -> Iterable[Iterable[T]]:
"""Iterates over an iterator chunk by chunk.
Taken from https://stackoverflow.com/a/8998040.
See also https://github.com/huggingface/huggingface_hub/pull/920#discussion_r938793088.
Args:
iterable (`Iterable`):
The iterable on which we want to iterate.
chunk_size (`int`):
Size of the chunks. Must be a strictly positive integer (i.e. > 0).
Example:
```python
>>> from huggingface_hub.utils import chunk_iterable
>>> for items in chunk_iterable(range(17), chunk_size=8):
... print(items)
# [0, 1, 2, 3, 4, 5, 6, 7]
# [8, 9, 10, 11, 12, 13, 14, 15]
# [16] # smaller last chunk
```
Raises:
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
If `chunk_size` <= 0.
<Tip warning={true}>
The last chunk can be smaller than `chunk_size`.
</Tip>
"""
if not isinstance(chunk_size, int) or chunk_size <= 0:
raise ValueError("`chunk_size` must be a strictly positive integer (>0).")
iterator = iter(iterable)
while True:
try:
next_item = next(iterator)
except StopIteration:
return
yield itertools.chain((next_item,), itertools.islice(iterator, chunk_size - 1))


@ -0,0 +1,62 @@
# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities to handle datetimes in Huggingface Hub."""
from datetime import datetime, timezone
def parse_datetime(date_string: str) -> datetime:
"""
Parses a date_string returned from the server to a datetime object.
This parser is a weak parser in the sense that it handles only a single format of
date_string. It is expected that the server format will never change. The
implementation depends only on the standard lib to avoid an external dependency
(python-dateutil). See full discussion about this decision on PR:
https://github.com/huggingface/huggingface_hub/pull/999.
Example:
```py
>>> parse_datetime('2022-08-19T07:19:38.123Z')
datetime.datetime(2022, 8, 19, 7, 19, 38, 123000, tzinfo=timezone.utc)
```
Args:
date_string (`str`):
A string representing a datetime returned by the Hub server.
String is expected to follow '%Y-%m-%dT%H:%M:%S.%fZ' pattern.
Returns:
A python datetime object.
Raises:
:class:`ValueError`:
If `date_string` cannot be parsed.
"""
try:
# Datetime ending with a Z means "UTC". We parse the date and then explicitly
# set the timezone to UTC.
# See https://en.wikipedia.org/wiki/ISO_8601#Coordinated_Universal_Time_(UTC)
# Taken from https://stackoverflow.com/a/3168394.
if len(date_string) == 30:
# Means timezoned-timestamp with nanoseconds precision. We need to truncate the last 3 digits.
date_string = date_string[:-4] + "Z"
dt = datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%S.%fZ")
return dt.replace(tzinfo=timezone.utc) # Set explicit timezone
except ValueError as e:
raise ValueError(
f"Cannot parse '{date_string}' as a datetime. Date string is expected to"
" follow '%Y-%m-%dT%H:%M:%S.%fZ' pattern."
) from e


@ -0,0 +1,136 @@
import warnings
from functools import wraps
from inspect import Parameter, signature
from typing import Iterable, Optional
def _deprecate_positional_args(*, version: str):
"""Decorator for methods that issues warnings for positional arguments.
Using the keyword-only argument syntax from PEP 3102, arguments after the
`*` will issue a warning when passed as positional arguments.
Args:
version (`str`):
The version when positional arguments will result in error.
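Example (illustrative sketch; the decorated function and version are made up):
```py
>>> @_deprecate_positional_args(version="1.0")
... def upload(repo_id, *, private=False):
...     ...
>>> upload("my-repo", True)  # warns: pass private=True as a keyword argument
```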
"""
def _inner_deprecate_positional_args(f):
sig = signature(f)
kwonly_args = []
all_args = []
for name, param in sig.parameters.items():
if param.kind == Parameter.POSITIONAL_OR_KEYWORD:
all_args.append(name)
elif param.kind == Parameter.KEYWORD_ONLY:
kwonly_args.append(name)
@wraps(f)
def inner_f(*args, **kwargs):
extra_args = len(args) - len(all_args)
if extra_args <= 0:
return f(*args, **kwargs)
# extra_args > 0
args_msg = [
f"{name}='{arg}'" if isinstance(arg, str) else f"{name}={arg}"
for name, arg in zip(kwonly_args[:extra_args], args[-extra_args:])
]
args_msg = ", ".join(args_msg)
warnings.warn(
f"Deprecated positional argument(s) used in '{f.__name__}': pass"
f" {args_msg} as keyword args. From version {version} passing these"
" as positional arguments will result in an error,",
FutureWarning,
)
kwargs.update(zip(sig.parameters, args))
return f(**kwargs)
return inner_f
return _inner_deprecate_positional_args
def _deprecate_arguments(
*,
version: str,
deprecated_args: Iterable[str],
custom_message: Optional[str] = None,
):
"""Decorator to issue warnings when using deprecated arguments.
TODO: could be useful to be able to set a custom error message.
Args:
version (`str`):
The version when deprecated arguments will result in error.
deprecated_args (`List[str]`):
List of the arguments to be deprecated.
custom_message (`str`, *optional*):
Warning message that is raised. If not passed, a default warning message
will be created.
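Example (illustrative sketch; names and version are made up):
```py
>>> @_deprecate_arguments(version="1.0", deprecated_args=["use_auth_token"])
... def download(repo_id, use_auth_token=None):
...     ...
>>> download("my-repo", use_auth_token="hf_xxx")  # warns about `use_auth_token`
```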
"""
def _inner_deprecate_positional_args(f):
sig = signature(f)
@wraps(f)
def inner_f(*args, **kwargs):
# Check for used deprecated arguments
used_deprecated_args = []
for _, parameter in zip(args, sig.parameters.values()):
if parameter.name in deprecated_args:
used_deprecated_args.append(parameter.name)
for kwarg_name, kwarg_value in kwargs.items():
if (
# If argument is deprecated but still used
kwarg_name in deprecated_args
# And then the value is not the default value
and kwarg_value != sig.parameters[kwarg_name].default
):
used_deprecated_args.append(kwarg_name)
# Warn and proceed
if len(used_deprecated_args) > 0:
message = (
f"Deprecated argument(s) used in '{f.__name__}':"
f" {', '.join(used_deprecated_args)}. Will not be supported from"
f" version '{version}'."
)
if custom_message is not None:
message += "\n\n" + custom_message
warnings.warn(message, FutureWarning)
return f(*args, **kwargs)
return inner_f
return _inner_deprecate_positional_args
def _deprecate_method(*, version: str, message: Optional[str] = None):
"""Decorator to issue warnings when using a deprecated method.
Args:
version (`str`):
The version when the deprecated method will result in an error.
message (`str`, *optional*):
Warning message that is raised. If not passed, a default warning message
will be created.
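Example (illustrative sketch; the method and version are made up):
```py
>>> @_deprecate_method(version="1.0", message="Use `new_method` instead.")
... def old_method():
...     ...
>>> old_method()  # warns that 'old_method' is deprecated
```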
"""
def _inner_deprecate_method(f):
name = f.__name__
if name == "__init__":
name = f.__qualname__.split(".")[0] # class name instead of method name
@wraps(f)
def inner_f(*args, **kwargs):
warning_message = (
f"'{name}' (from '{f.__module__}') is deprecated and will be removed from version '{version}'."
)
if message is not None:
warning_message += " " + message
warnings.warn(warning_message, FutureWarning)
return f(*args, **kwargs)
return inner_f
return _inner_deprecate_method


@ -0,0 +1,66 @@
# coding=utf-8
# Copyright 2023-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities to flag a feature as "experimental" in Huggingface Hub."""
import warnings
from functools import wraps
from typing import Callable
from .. import constants
def experimental(fn: Callable) -> Callable:
"""Decorator to flag a feature as experimental.
An experimental feature triggers a warning when used, as it might be subject to breaking changes in the future.
Warnings can be disabled by setting the environment variable `HF_HUB_DISABLE_EXPERIMENTAL_WARNING` to `1`.
Args:
fn (`Callable`):
The function to flag as experimental.
Returns:
`Callable`: The decorated function.
Example:
```python
>>> from huggingface_hub.utils import experimental
>>> @experimental
... def my_function():
... print("Hello world!")
>>> my_function()
UserWarning: 'my_function' is experimental and might be subject to breaking changes in the future. You can disable
this warning by setting `HF_HUB_DISABLE_EXPERIMENTAL_WARNING=1` as environment variable.
Hello world!
```
"""
# For classes, put the "experimental" around the "__new__" method => __new__ will be removed in warning message
name = fn.__qualname__[: -len(".__new__")] if fn.__qualname__.endswith(".__new__") else fn.__qualname__
@wraps(fn)
def _inner_fn(*args, **kwargs):
if not constants.HF_HUB_DISABLE_EXPERIMENTAL_WARNING:
warnings.warn(
f"'{name}' is experimental and might be subject to breaking changes in the future."
" You can disable this warning by setting `HF_HUB_DISABLE_EXPERIMENTAL_WARNING=1` as environment"
" variable.",
UserWarning,
)
return fn(*args, **kwargs)
return _inner_fn


@ -0,0 +1,121 @@
# JSONDecodeError was introduced in requests=2.27 released in 2022.
# This allows us to support older requests for users
# More information: https://github.com/psf/requests/pull/5856
try:
from requests import JSONDecodeError # type: ignore # noqa: F401
except ImportError:
try:
from simplejson import JSONDecodeError # type: ignore # noqa: F401
except ImportError:
from json import JSONDecodeError # type: ignore # noqa: F401
import contextlib
import os
import shutil
import stat
import tempfile
from functools import partial
from pathlib import Path
from typing import Callable, Generator, Optional, Union
import yaml
from filelock import BaseFileLock, FileLock, SoftFileLock, Timeout
from .. import constants
from . import logging
logger = logging.get_logger(__name__)
# Wrap `yaml.dump` to set `allow_unicode=True` by default.
#
# Example:
# ```py
# >>> yaml.dump({"emoji": "👀", "some unicode": "日本か"})
# 'emoji: "\\U0001F440"\nsome unicode: "\\u65E5\\u672C\\u304B"\n'
#
# >>> yaml_dump({"emoji": "👀", "some unicode": "日本か"})
# 'emoji: "👀"\nsome unicode: "日本か"\n'
# ```
yaml_dump: Callable[..., str] = partial(yaml.dump, stream=None, allow_unicode=True) # type: ignore
@contextlib.contextmanager
def SoftTemporaryDirectory(
suffix: Optional[str] = None,
prefix: Optional[str] = None,
dir: Optional[Union[Path, str]] = None,
**kwargs,
) -> Generator[Path, None, None]:
"""
Context manager to create a temporary directory and safely delete it.
If tmp directory cannot be deleted normally, we set the WRITE permission and retry.
If cleanup still fails, we give up but don't raise an exception. This is equivalent
to `tempfile.TemporaryDirectory(..., ignore_cleanup_errors=True)` introduced in
Python 3.10.
See https://www.scivision.dev/python-tempfile-permission-error-windows/.
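Example (illustrative):
```py
>>> with SoftTemporaryDirectory(prefix="hf-") as tmpdir:
...     (tmpdir / "file.txt").write_text("hello")  # `tmpdir` is a `Path`
5
```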
"""
tmpdir = tempfile.TemporaryDirectory(prefix=prefix, suffix=suffix, dir=dir, **kwargs)
yield Path(tmpdir.name).resolve()
try:
# First once with normal cleanup
shutil.rmtree(tmpdir.name)
except Exception:
# If failed, try to set write permission and retry
try:
shutil.rmtree(tmpdir.name, onerror=_set_write_permission_and_retry)
except Exception:
pass
# And finally, cleanup the tmpdir.
# If it fails again, give up but do not throw error
try:
tmpdir.cleanup()
except Exception:
pass
def _set_write_permission_and_retry(func, path, excinfo):
os.chmod(path, stat.S_IWRITE)
func(path)
@contextlib.contextmanager
def WeakFileLock(lock_file: Union[str, Path]) -> Generator[BaseFileLock, None, None]:
"""A filelock with some custom logic.
This filelock is weaker than the default filelock in that:
1. It won't raise an exception if release fails.
2. It will default to a SoftFileLock if the filesystem does not support flock.
An INFO log message is emitted every 10 seconds if the lock is not acquired immediately.
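Example (illustrative; the lock path is made up):
```py
>>> with WeakFileLock("/tmp/my-app.lock"):
...     ...  # critical section, protected across processes
```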
"""
lock = FileLock(lock_file, timeout=constants.FILELOCK_LOG_EVERY_SECONDS)
while True:
try:
lock.acquire()
except Timeout:
logger.info("still waiting to acquire lock on %s", lock_file)
except NotImplementedError as e:
if "use SoftFileLock instead" in str(e):
# It's possible that the system does support flock, except for one partition or filesystem.
# In this case, let's default to a SoftFileLock.
logger.warning(
"FileSystem does not appear to support flock. Falling back to SoftFileLock for %s", lock_file
)
lock = SoftFileLock(lock_file, timeout=constants.FILELOCK_LOG_EVERY_SECONDS)
continue
else:
break
yield lock
try:
return lock.release()
except OSError:
try:
Path(lock_file).unlink()
except OSError:
pass


@ -0,0 +1,121 @@
# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities to manage Git credentials."""
import re
import subprocess
from typing import List, Optional
from ..constants import ENDPOINT
from ._subprocess import run_interactive_subprocess, run_subprocess
GIT_CREDENTIAL_REGEX = re.compile(
r"""
^\s* # start of line
credential\.helper # credential.helper value
\s*=\s* # separator
(\w+) # the helper name (group 1)
(\s|$) # whitespace or end of line
""",
flags=re.MULTILINE | re.IGNORECASE | re.VERBOSE,
)
def list_credential_helpers(folder: Optional[str] = None) -> List[str]:
"""Return the list of git credential helpers configured.
See https://git-scm.com/docs/gitcredentials.
    Helpers are read from the local git configuration: calls "`git config --list`" internally.
Args:
folder (`str`, *optional*):
The folder in which to check the configured helpers.
"""
try:
output = run_subprocess("git config --list", folder=folder).stdout
parsed = _parse_credential_output(output)
return parsed
except subprocess.CalledProcessError as exc:
raise EnvironmentError(exc.stderr)
def set_git_credential(token: str, username: str = "hf_user", folder: Optional[str] = None) -> None:
"""Save a username/token pair in git credential for HF Hub registry.
Credentials are saved in all configured helpers (store, cache, macOS keychain,...).
Calls "`git credential approve`" internally. See https://git-scm.com/docs/git-credential.
Args:
username (`str`, defaults to `"hf_user"`):
A git username. Defaults to `"hf_user"`, the default user used in the Hub.
        token (`str`):
A git password. In practice, the User Access Token for the Hub.
See https://huggingface.co/settings/tokens.
folder (`str`, *optional*):
The folder in which to check the configured helpers.
"""
with run_interactive_subprocess("git credential approve", folder=folder) as (
stdin,
_,
):
stdin.write(f"url={ENDPOINT}\nusername={username.lower()}\npassword={token}\n\n")
stdin.flush()
def unset_git_credential(username: str = "hf_user", folder: Optional[str] = None) -> None:
"""Erase credentials from git credential for HF Hub registry.
Credentials are erased from the configured helpers (store, cache, macOS
keychain,...), if any. If `username` is not provided, any credential configured for
HF Hub endpoint is erased.
Calls "`git credential erase`" internally. See https://git-scm.com/docs/git-credential.
Args:
username (`str`, defaults to `"hf_user"`):
A git username. Defaults to `"hf_user"`, the default user used in the Hub.
folder (`str`, *optional*):
The folder in which to check the configured helpers.
"""
with run_interactive_subprocess("git credential reject", folder=folder) as (
stdin,
_,
):
standard_input = f"url={ENDPOINT}\n"
if username is not None:
standard_input += f"username={username.lower()}\n"
standard_input += "\n"
stdin.write(standard_input)
stdin.flush()
def _parse_credential_output(output: str) -> List[str]:
"""Parse the output of `git credential fill` to extract the password.
Args:
output (`str`):
The output of `git credential fill`.
"""
    # NOTE: If a user has set a helper for a custom URL, it will not be caught here.
# Example: `credential.https://huggingface.co.helper=store`
# See: https://github.com/huggingface/huggingface_hub/pull/1138#discussion_r1013324508
return sorted( # Sort for nice printing
set( # Might have some duplicates
match[0] for match in GIT_CREDENTIAL_REGEX.findall(output)
)
)
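# Worked example for the parsing above; the configuration values are illustrative. Given a
# `git config --list` output such as:
#
#   credential.helper=store
#   credential.helper=osxkeychain
#   user.name=Alice
#
# `_parse_credential_output` returns `["osxkeychain", "store"]` (deduplicated, sorted).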

View File

@ -0,0 +1,239 @@
# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities to handle headers to send in calls to Huggingface Hub."""
from typing import Dict, Optional, Union
from huggingface_hub.errors import LocalTokenNotFoundError
from .. import constants
from ._auth import get_token
from ._runtime import (
get_fastai_version,
get_fastcore_version,
get_hf_hub_version,
get_python_version,
get_tf_version,
get_torch_version,
is_fastai_available,
is_fastcore_available,
is_tf_available,
is_torch_available,
)
from ._validators import validate_hf_hub_args
@validate_hf_hub_args
def build_hf_headers(
*,
token: Optional[Union[bool, str]] = None,
is_write_action: bool = False,
library_name: Optional[str] = None,
library_version: Optional[str] = None,
user_agent: Union[Dict, str, None] = None,
headers: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
"""
Build headers dictionary to send in a HF Hub call.
By default, authorization token is always provided either from argument (explicit
use) or retrieved from the cache (implicit use). To explicitly avoid sending the
token to the Hub, set `token=False` or set the `HF_HUB_DISABLE_IMPLICIT_TOKEN`
environment variable.
In case of an API call that requires write access, an error is thrown if token is
`None` or token is an organization token (starting with `"api_org***"`).
In addition to the auth header, a user-agent is added to provide information about
the installed packages (versions of python, huggingface_hub, torch, tensorflow,
fastai and fastcore).
Args:
token (`str`, `bool`, *optional*):
The token to be sent in authorization header for the Hub call:
- if a string, it is used as the Hugging Face token
- if `True`, the token is read from the machine (cache or env variable)
- if `False`, authorization header is not set
- if `None`, the token is read from the machine only except if
`HF_HUB_DISABLE_IMPLICIT_TOKEN` env variable is set.
is_write_action (`bool`, default to `False`):
Set to True if the API call requires a write access. If `True`, the token
will be validated (cannot be `None`, cannot start by `"api_org***"`).
library_name (`str`, *optional*):
The name of the library that is making the HTTP request. Will be added to
the user-agent header.
library_version (`str`, *optional*):
The version of the library that is making the HTTP request. Will be added
to the user-agent header.
user_agent (`str`, `dict`, *optional*):
The user agent info in the form of a dictionary or a single string. It will
be completed with information about the installed packages.
headers (`dict`, *optional*):
Additional headers to include in the request. Those headers take precedence
over the ones generated by this function.
Returns:
A `Dict` of headers to pass in your API call.
Example:
```py
>>> build_hf_headers(token="hf_***") # explicit token
{"authorization": "Bearer hf_***", "user-agent": ""}
>>> build_hf_headers(token=True) # explicitly use cached token
{"authorization": "Bearer hf_***",...}
>>> build_hf_headers(token=False) # explicitly don't use cached token
{"user-agent": ...}
>>> build_hf_headers() # implicit use of the cached token
{"authorization": "Bearer hf_***",...}
# HF_HUB_DISABLE_IMPLICIT_TOKEN=True # to set as env variable
>>> build_hf_headers() # token is not sent
{"user-agent": ...}
>>> build_hf_headers(token="api_org_***", is_write_action=True)
ValueError: You must use your personal account token for write-access methods.
>>> build_hf_headers(library_name="transformers", library_version="1.2.3")
{"authorization": ..., "user-agent": "transformers/1.2.3; hf_hub/0.10.2; python/3.10.4; tensorflow/1.55"}
```
Raises:
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
If organization token is passed and "write" access is required.
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
If "write" access is required but token is not passed and not saved locally.
[`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
If `token=True` but token is not saved locally.
"""
# Get auth token to send
token_to_send = get_token_to_send(token)
_validate_token_to_send(token_to_send, is_write_action=is_write_action)
# Combine headers
hf_headers = {
"user-agent": _http_user_agent(
library_name=library_name,
library_version=library_version,
user_agent=user_agent,
)
}
if token_to_send is not None:
hf_headers["authorization"] = f"Bearer {token_to_send}"
if headers is not None:
hf_headers.update(headers)
return hf_headers
def get_token_to_send(token: Optional[Union[bool, str]]) -> Optional[str]:
"""Select the token to send from either `token` or the cache."""
# Case token is explicitly provided
if isinstance(token, str):
return token
# Case token is explicitly forbidden
if token is False:
return None
# Token is not provided: we get it from local cache
cached_token = get_token()
# Case token is explicitly required
if token is True:
if cached_token is None:
raise LocalTokenNotFoundError(
"Token is required (`token=True`), but no token found. You"
" need to provide a token or be logged in to Hugging Face with"
" `huggingface-cli login` or `huggingface_hub.login`. See"
" https://huggingface.co/settings/tokens."
)
return cached_token
# Case implicit use of the token is forbidden by env variable
if constants.HF_HUB_DISABLE_IMPLICIT_TOKEN:
return None
# Otherwise: we use the cached token as the user has not explicitly forbidden it
return cached_token
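# Resolution-order sketch for `get_token_to_send`; the token values are illustrative.
#
#   get_token_to_send("hf_abc")  # -> "hf_abc" (explicit token always wins)
#   get_token_to_send(False)     # -> None (never send a token)
#   get_token_to_send(True)      # -> cached token, or LocalTokenNotFoundError if none is stored
#   get_token_to_send(None)      # -> cached token, unless HF_HUB_DISABLE_IMPLICIT_TOKEN is set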
def _validate_token_to_send(token: Optional[str], is_write_action: bool) -> None:
if is_write_action:
if token is None:
raise ValueError(
"Token is required (write-access action) but no token found. You need"
" to provide a token or be logged in to Hugging Face with"
" `huggingface-cli login` or `huggingface_hub.login`. See"
" https://huggingface.co/settings/tokens."
)
if token.startswith("api_org"):
raise ValueError(
"You must use your personal account token for write-access methods. To"
" generate a write-access token, go to"
" https://huggingface.co/settings/tokens"
)
def _http_user_agent(
*,
library_name: Optional[str] = None,
library_version: Optional[str] = None,
user_agent: Union[Dict, str, None] = None,
) -> str:
"""Format a user-agent string containing information about the installed packages.
Args:
library_name (`str`, *optional*):
The name of the library that is making the HTTP request.
library_version (`str`, *optional*):
The version of the library that is making the HTTP request.
user_agent (`str`, `dict`, *optional*):
The user agent info in the form of a dictionary or a single string.
Returns:
The formatted user-agent string.
"""
if library_name is not None:
ua = f"{library_name}/{library_version}"
else:
ua = "unknown/None"
ua += f"; hf_hub/{get_hf_hub_version()}"
ua += f"; python/{get_python_version()}"
if not constants.HF_HUB_DISABLE_TELEMETRY:
if is_torch_available():
ua += f"; torch/{get_torch_version()}"
if is_tf_available():
ua += f"; tensorflow/{get_tf_version()}"
if is_fastai_available():
ua += f"; fastai/{get_fastai_version()}"
if is_fastcore_available():
ua += f"; fastcore/{get_fastcore_version()}"
if isinstance(user_agent, dict):
ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
elif isinstance(user_agent, str):
ua += "; " + user_agent
return _deduplicate_user_agent(ua)
def _deduplicate_user_agent(user_agent: str) -> str:
"""Deduplicate redundant information in the generated user-agent."""
# Split around ";" > Strip whitespaces > Store as dict keys (ensure unicity) > format back as string
# Order is implicitly preserved by dictionary structure (see https://stackoverflow.com/a/53657523).
return "; ".join({key.strip(): None for key in user_agent.split(";")}.keys())

View File

@ -0,0 +1,96 @@
# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contain helper class to retrieve/store token from/to local cache."""
import warnings
from pathlib import Path
from typing import Optional
from .. import constants
from ._auth import get_token
class HfFolder:
path_token = Path(constants.HF_TOKEN_PATH)
# Private attribute. Will be removed in v0.15
_old_path_token = Path(constants._OLD_HF_TOKEN_PATH)
# TODO: deprecate when adapted in transformers/datasets/gradio
# @_deprecate_method(version="1.0", message="Use `huggingface_hub.login` instead.")
@classmethod
def save_token(cls, token: str) -> None:
"""
Save token, creating folder as needed.
Token is saved in the huggingface home folder. You can configure it by setting
the `HF_HOME` environment variable.
Args:
token (`str`):
The token to save to the [`HfFolder`]
"""
cls.path_token.parent.mkdir(parents=True, exist_ok=True)
cls.path_token.write_text(token)
# TODO: deprecate when adapted in transformers/datasets/gradio
# @_deprecate_method(version="1.0", message="Use `huggingface_hub.get_token` instead.")
@classmethod
def get_token(cls) -> Optional[str]:
"""
Get token or None if not existent.
This method is deprecated in favor of [`huggingface_hub.get_token`] but is kept for backward compatibility.
Its behavior is the same as [`huggingface_hub.get_token`].
Returns:
`str` or `None`: The token, `None` if it doesn't exist.
"""
# 0. Check if token exist in old path but not new location
try:
cls._copy_to_new_path_and_warn()
except Exception: # if not possible (e.g. PermissionError), do not raise
pass
return get_token()
# TODO: deprecate when adapted in transformers/datasets/gradio
# @_deprecate_method(version="1.0", message="Use `huggingface_hub.logout` instead.")
@classmethod
def delete_token(cls) -> None:
"""
Deletes the token from storage. Does not fail if token does not exist.
"""
try:
cls.path_token.unlink()
except FileNotFoundError:
pass
try:
cls._old_path_token.unlink()
except FileNotFoundError:
pass
@classmethod
def _copy_to_new_path_and_warn(cls):
if cls._old_path_token.exists() and not cls.path_token.exists():
cls.save_token(cls._old_path_token.read_text())
warnings.warn(
f"A token has been found in `{cls._old_path_token}`. This is the old"
" path where tokens were stored. The new location is"
f" `{cls.path_token}` which is configurable using `HF_HOME` environment"
" variable. Your token has been copied to this new location. You can"
" now safely delete the old token file manually or use"
" `huggingface-cli logout`."
)
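# Usage sketch; the token value is illustrative. `HfFolder` is kept for backward compatibility;
# new code should prefer `huggingface_hub.login`, `huggingface_hub.get_token` and `logout`.
#
#   from huggingface_hub.utils import HfFolder
#
#   HfFolder.save_token("hf_xxx")  # writes the token to HF_TOKEN_PATH
#   HfFolder.get_token()           # reads it back (delegates to `huggingface_hub.get_token`)
#   HfFolder.delete_token()        # removes the token file(s); silent if none exists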

View File

@ -0,0 +1,551 @@
# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities to handle HTTP requests in Huggingface Hub."""
import io
import os
import re
import threading
import time
import uuid
from functools import lru_cache
from http import HTTPStatus
from typing import Callable, Optional, Tuple, Type, Union
import requests
from requests import HTTPError, Response
from requests.adapters import HTTPAdapter
from requests.models import PreparedRequest
from huggingface_hub.errors import OfflineModeIsEnabled
from .. import constants
from ..errors import (
BadRequestError,
DisabledRepoError,
EntryNotFoundError,
GatedRepoError,
HfHubHTTPError,
RepositoryNotFoundError,
RevisionNotFoundError,
)
from . import logging
from ._fixes import JSONDecodeError
from ._lfs import SliceFileObj
from ._typing import HTTP_METHOD_T
logger = logging.get_logger(__name__)
# Both headers are used by the Hub to debug failed requests.
# `X_AMZN_TRACE_ID` is better as it also works to debug on Cloudfront and ALB.
# If `X_AMZN_TRACE_ID` is set, the Hub will use it as well.
X_AMZN_TRACE_ID = "X-Amzn-Trace-Id"
X_REQUEST_ID = "x-request-id"
REPO_API_REGEX = re.compile(
r"""
# staging or production endpoint
^https://[^/]+
(
# on /api/repo_type/repo_id
/api/(models|datasets|spaces)/(.+)
|
# or /repo_id/resolve/revision/...
/(.+)/resolve/(.+)
)
""",
flags=re.VERBOSE,
)
class UniqueRequestIdAdapter(HTTPAdapter):
X_AMZN_TRACE_ID = "X-Amzn-Trace-Id"
def add_headers(self, request, **kwargs):
super().add_headers(request, **kwargs)
# Add random request ID => easier for server-side debug
if X_AMZN_TRACE_ID not in request.headers:
request.headers[X_AMZN_TRACE_ID] = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4())
# Add debug log
has_token = str(request.headers.get("authorization", "")).startswith("Bearer hf_")
logger.debug(
f"Request {request.headers[X_AMZN_TRACE_ID]}: {request.method} {request.url} (authenticated: {has_token})"
)
def send(self, request: PreparedRequest, *args, **kwargs) -> Response:
"""Catch any RequestException to append request id to the error message for debugging."""
try:
return super().send(request, *args, **kwargs)
except requests.RequestException as e:
request_id = request.headers.get(X_AMZN_TRACE_ID)
if request_id is not None:
# Taken from https://stackoverflow.com/a/58270258
e.args = (*e.args, f"(Request ID: {request_id})")
raise
class OfflineAdapter(HTTPAdapter):
def send(self, request: PreparedRequest, *args, **kwargs) -> Response:
raise OfflineModeIsEnabled(
f"Cannot reach {request.url}: offline mode is enabled. To disable it, please unset the `HF_HUB_OFFLINE` environment variable."
)
def _default_backend_factory() -> requests.Session:
session = requests.Session()
if constants.HF_HUB_OFFLINE:
session.mount("http://", OfflineAdapter())
session.mount("https://", OfflineAdapter())
else:
session.mount("http://", UniqueRequestIdAdapter())
session.mount("https://", UniqueRequestIdAdapter())
return session
BACKEND_FACTORY_T = Callable[[], requests.Session]
_GLOBAL_BACKEND_FACTORY: BACKEND_FACTORY_T = _default_backend_factory
def configure_http_backend(backend_factory: BACKEND_FACTORY_T = _default_backend_factory) -> None:
"""
Configure the HTTP backend by providing a `backend_factory`. Any HTTP calls made by `huggingface_hub` will use a
Session object instantiated by this factory. This can be useful if you are running your scripts in a specific
environment requiring custom configuration (e.g. custom proxy or certifications).
Use [`get_session`] to get a configured Session. Since `requests.Session` is not guaranteed to be thread-safe,
`huggingface_hub` creates 1 Session instance per thread. They are all instantiated using the same `backend_factory`
set in [`configure_http_backend`]. A LRU cache is used to cache the created sessions (and connections) between
calls. Max size is 128 to avoid memory leaks if thousands of threads are spawned.
See [this issue](https://github.com/psf/requests/issues/2766) to know more about thread-safety in `requests`.
Example:
```py
import requests
from huggingface_hub import configure_http_backend, get_session
# Create a factory function that returns a Session with configured proxies
def backend_factory() -> requests.Session:
session = requests.Session()
session.proxies = {"http": "http://10.10.1.10:3128", "https": "https://10.10.1.11:1080"}
return session
# Set it as the default session factory
configure_http_backend(backend_factory=backend_factory)
# In practice, this is mostly done internally in `huggingface_hub`
session = get_session()
```
"""
global _GLOBAL_BACKEND_FACTORY
_GLOBAL_BACKEND_FACTORY = backend_factory
reset_sessions()
def get_session() -> requests.Session:
"""
Get a `requests.Session` object, using the session factory from the user.
Use [`get_session`] to get a configured Session. Since `requests.Session` is not guaranteed to be thread-safe,
`huggingface_hub` creates 1 Session instance per thread. They are all instantiated using the same `backend_factory`
set in [`configure_http_backend`]. A LRU cache is used to cache the created sessions (and connections) between
calls. Max size is 128 to avoid memory leaks if thousands of threads are spawned.
See [this issue](https://github.com/psf/requests/issues/2766) to know more about thread-safety in `requests`.
Example:
```py
import requests
from huggingface_hub import configure_http_backend, get_session
# Create a factory function that returns a Session with configured proxies
def backend_factory() -> requests.Session:
session = requests.Session()
session.proxies = {"http": "http://10.10.1.10:3128", "https": "https://10.10.1.11:1080"}
return session
# Set it as the default session factory
configure_http_backend(backend_factory=backend_factory)
# In practice, this is mostly done internally in `huggingface_hub`
session = get_session()
```
"""
return _get_session_from_cache(process_id=os.getpid(), thread_id=threading.get_ident())
def reset_sessions() -> None:
"""Reset the cache of sessions.
Mostly used internally when sessions are reconfigured or an SSLError is raised.
See [`configure_http_backend`] for more details.
"""
_get_session_from_cache.cache_clear()
@lru_cache
def _get_session_from_cache(process_id: int, thread_id: int) -> requests.Session:
"""
Create a new session per thread using global factory. Using LRU cache (maxsize 128) to avoid memory leaks when
using thousands of threads. Cache is cleared when `configure_http_backend` is called.
"""
return _GLOBAL_BACKEND_FACTORY()
def http_backoff(
method: HTTP_METHOD_T,
url: str,
*,
max_retries: int = 5,
base_wait_time: float = 1,
max_wait_time: float = 8,
retry_on_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = (
requests.Timeout,
requests.ConnectionError,
),
retry_on_status_codes: Union[int, Tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE,
**kwargs,
) -> Response:
"""Wrapper around requests to retry calls on an endpoint, with exponential backoff.
Endpoint call is retried on exceptions (ex: connection timeout, proxy error,...)
and/or on specific status codes (ex: service unavailable). If the call failed more
than `max_retries`, the exception is thrown or `raise_for_status` is called on the
response object.
    Re-implement mechanisms from the `backoff` library to avoid adding an external
    dependency to `huggingface_hub`. See https://github.com/litl/backoff.
Args:
method (`Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"]`):
HTTP method to perform.
url (`str`):
The URL of the resource to fetch.
max_retries (`int`, *optional*, defaults to `5`):
            Maximum number of retries. Set to `0` to disable retrying.
base_wait_time (`float`, *optional*, defaults to `1`):
Duration (in seconds) to wait before retrying the first time.
Wait time between retries then grows exponentially, capped by
`max_wait_time`.
max_wait_time (`float`, *optional*, defaults to `8`):
Maximum duration (in seconds) to wait before retrying.
retry_on_exceptions (`Type[Exception]` or `Tuple[Type[Exception]]`, *optional*):
Define which exceptions must be caught to retry the request. Can be a single type or a tuple of types.
By default, retry on `requests.Timeout` and `requests.ConnectionError`.
retry_on_status_codes (`int` or `Tuple[int]`, *optional*, defaults to `503`):
Define on which status codes the request must be retried. By default, only
HTTP 503 Service Unavailable is retried.
**kwargs (`dict`, *optional*):
kwargs to pass to `requests.request`.
Example:
```
>>> from huggingface_hub.utils import http_backoff
# Same usage as "requests.request".
>>> response = http_backoff("GET", "https://www.google.com")
>>> response.raise_for_status()
# If you expect a Gateway Timeout from time to time
>>> http_backoff("PUT", upload_url, data=data, retry_on_status_codes=504)
>>> response.raise_for_status()
```
<Tip warning={true}>
When using `requests` it is possible to stream data by passing an iterator to the
`data` argument. On http backoff this is a problem as the iterator is not reset
after a failed call. This issue is mitigated for file objects or any IO streams
by saving the initial position of the cursor (with `data.tell()`) and resetting the
cursor between each call (with `data.seek()`). For arbitrary iterators, http backoff
will fail. If this is a hard constraint for you, please let us know by opening an
issue on [Github](https://github.com/huggingface/huggingface_hub).
</Tip>
"""
if isinstance(retry_on_exceptions, type): # Tuple from single exception type
retry_on_exceptions = (retry_on_exceptions,)
if isinstance(retry_on_status_codes, int): # Tuple from single status code
retry_on_status_codes = (retry_on_status_codes,)
nb_tries = 0
sleep_time = base_wait_time
# If `data` is used and is a file object (or any IO), it will be consumed on the
# first HTTP request. We need to save the initial position so that the full content
# of the file is re-sent on http backoff. See warning tip in docstring.
io_obj_initial_pos = None
if "data" in kwargs and isinstance(kwargs["data"], (io.IOBase, SliceFileObj)):
io_obj_initial_pos = kwargs["data"].tell()
session = get_session()
while True:
nb_tries += 1
try:
# If `data` is used and is a file object (or any IO), set back cursor to
# initial position.
if io_obj_initial_pos is not None:
kwargs["data"].seek(io_obj_initial_pos)
# Perform request and return if status_code is not in the retry list.
response = session.request(method=method, url=url, **kwargs)
if response.status_code not in retry_on_status_codes:
return response
# Wrong status code returned (HTTP 503 for instance)
logger.warning(f"HTTP Error {response.status_code} thrown while requesting {method} {url}")
if nb_tries > max_retries:
response.raise_for_status() # Will raise uncaught exception
# We return response to avoid infinite loop in the corner case where the
            # user asks for a retry on a status code that doesn't raise_for_status.
return response
except retry_on_exceptions as err:
logger.warning(f"'{err}' thrown while requesting {method} {url}")
if isinstance(err, requests.ConnectionError):
reset_sessions() # In case of SSLError it's best to reset the shared requests.Session objects
if nb_tries > max_retries:
raise err
# Sleep for X seconds
logger.warning(f"Retrying in {sleep_time}s [Retry {nb_tries}/{max_retries}].")
time.sleep(sleep_time)
# Update sleep time for next retry
sleep_time = min(max_wait_time, sleep_time * 2) # Exponential backoff
def fix_hf_endpoint_in_url(url: str, endpoint: Optional[str]) -> str:
"""Replace the default endpoint in a URL by a custom one.
This is useful when using a proxy and the Hugging Face Hub returns a URL with the default endpoint.
"""
endpoint = endpoint or constants.ENDPOINT
# check if a proxy has been set => if yes, update the returned URL to use the proxy
if endpoint not in (None, constants._HF_DEFAULT_ENDPOINT, constants._HF_DEFAULT_STAGING_ENDPOINT):
url = url.replace(constants._HF_DEFAULT_ENDPOINT, endpoint)
url = url.replace(constants._HF_DEFAULT_STAGING_ENDPOINT, endpoint)
return url
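# Worked example; the custom endpoint is illustrative and assumes the default public endpoint
# is `https://huggingface.co`.
#
#   fix_hf_endpoint_in_url(
#       "https://huggingface.co/gpt2/resolve/main/config.json",
#       endpoint="https://hub.company.internal",
#   )
#   # -> "https://hub.company.internal/gpt2/resolve/main/config.json"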
def hf_raise_for_status(response: Response, endpoint_name: Optional[str] = None) -> None:
"""
Internal version of `response.raise_for_status()` that will refine a
potential HTTPError. Raised exception will be an instance of `HfHubHTTPError`.
This helper is meant to be the unique method to raise_for_status when making a call
to the Hugging Face Hub.
Example:
```py
import requests
from huggingface_hub.utils import get_session, hf_raise_for_status, HfHubHTTPError
response = get_session().post(...)
try:
hf_raise_for_status(response)
except HfHubHTTPError as e:
print(str(e)) # formatted message
e.request_id, e.server_message # details returned by server
# Complete the error message with additional information once it's raised
e.append_to_message("\n`create_commit` expects the repository to exist.")
raise
```
Args:
response (`Response`):
Response from the server.
endpoint_name (`str`, *optional*):
Name of the endpoint that has been called. If provided, the error message
will be more complete.
<Tip warning={true}>
Raises when the request has failed:
- [`~utils.RepositoryNotFoundError`]
If the repository to download from cannot be found. This may be because it
doesn't exist, because `repo_type` is not set correctly, or because the repo
is `private` and you do not have access.
- [`~utils.GatedRepoError`]
If the repository exists but is gated and the user is not on the authorized
list.
- [`~utils.RevisionNotFoundError`]
            If the repository exists but the revision couldn't be found.
- [`~utils.EntryNotFoundError`]
            If the repository exists but the entry (e.g. the requested file) couldn't be
            found.
- [`~utils.BadRequestError`]
If request failed with a HTTP 400 BadRequest error.
- [`~utils.HfHubHTTPError`]
If request failed for a reason not listed above.
</Tip>
"""
try:
response.raise_for_status()
except HTTPError as e:
error_code = response.headers.get("X-Error-Code")
error_message = response.headers.get("X-Error-Message")
if error_code == "RevisionNotFound":
message = f"{response.status_code} Client Error." + "\n\n" + f"Revision Not Found for url: {response.url}."
raise _format(RevisionNotFoundError, message, response) from e
elif error_code == "EntryNotFound":
message = f"{response.status_code} Client Error." + "\n\n" + f"Entry Not Found for url: {response.url}."
raise _format(EntryNotFoundError, message, response) from e
elif error_code == "GatedRepo":
message = (
f"{response.status_code} Client Error." + "\n\n" + f"Cannot access gated repo for url {response.url}."
)
raise _format(GatedRepoError, message, response) from e
elif error_message == "Access to this resource is disabled.":
message = (
f"{response.status_code} Client Error."
+ "\n\n"
+ f"Cannot access repository for url {response.url}."
+ "\n"
+ "Access to this resource is disabled."
)
raise _format(DisabledRepoError, message, response) from e
elif error_code == "RepoNotFound" or (
response.status_code == 401
and response.request is not None
and response.request.url is not None
and REPO_API_REGEX.search(response.request.url) is not None
):
# 401 is misleading as it is returned for:
# - private and gated repos if user is not authenticated
# - missing repos
# => for now, we process them as `RepoNotFound` anyway.
# See https://gist.github.com/Wauplin/46c27ad266b15998ce56a6603796f0b9
message = (
f"{response.status_code} Client Error."
+ "\n\n"
+ f"Repository Not Found for url: {response.url}."
+ "\nPlease make sure you specified the correct `repo_id` and"
" `repo_type`.\nIf you are trying to access a private or gated repo,"
" make sure you are authenticated."
)
raise _format(RepositoryNotFoundError, message, response) from e
elif response.status_code == 400:
message = (
f"\n\nBad request for {endpoint_name} endpoint:" if endpoint_name is not None else "\n\nBad request:"
)
raise _format(BadRequestError, message, response) from e
elif response.status_code == 403:
message = (
f"\n\n{response.status_code} Forbidden: {error_message}."
+ f"\nCannot access content at: {response.url}."
+ "\nMake sure your token has the correct permissions."
)
raise _format(HfHubHTTPError, message, response) from e
elif response.status_code == 416:
range_header = response.request.headers.get("Range")
message = f"{e}. Requested range: {range_header}. Content-Range: {response.headers.get('Content-Range')}."
raise _format(HfHubHTTPError, message, response) from e
# Convert `HTTPError` into a `HfHubHTTPError` to display request information
# as well (request id and/or server error message)
raise _format(HfHubHTTPError, str(e), response) from e
def _format(error_type: Type[HfHubHTTPError], custom_message: str, response: Response) -> HfHubHTTPError:
server_errors = []
# Retrieve server error from header
from_headers = response.headers.get("X-Error-Message")
if from_headers is not None:
server_errors.append(from_headers)
# Retrieve server error from body
try:
# Case errors are returned in a JSON format
data = response.json()
error = data.get("error")
if error is not None:
if isinstance(error, list):
# Case {'error': ['my error 1', 'my error 2']}
server_errors.extend(error)
else:
# Case {'error': 'my error'}
server_errors.append(error)
errors = data.get("errors")
if errors is not None:
# Case {'errors': [{'message': 'my error 1'}, {'message': 'my error 2'}]}
for error in errors:
if "message" in error:
server_errors.append(error["message"])
except JSONDecodeError:
# If content is not JSON and not HTML, append the text
content_type = response.headers.get("Content-Type", "")
if response.text and "html" not in content_type.lower():
server_errors.append(response.text)
# Strip all server messages
server_errors = [line.strip() for line in server_errors if line.strip()]
# Deduplicate server messages (keep order)
# taken from https://stackoverflow.com/a/17016257
server_errors = list(dict.fromkeys(server_errors))
# Format server error
server_message = "\n".join(server_errors)
# Add server error to custom message
final_error_message = custom_message
if server_message and server_message.lower() not in custom_message.lower():
if "\n\n" in custom_message:
final_error_message += "\n" + server_message
else:
final_error_message += "\n\n" + server_message
# Add Request ID
request_id = str(response.headers.get(X_REQUEST_ID, ""))
if request_id:
request_id_message = f" (Request ID: {request_id})"
else:
# Fallback to X-Amzn-Trace-Id
request_id = str(response.headers.get(X_AMZN_TRACE_ID, ""))
if request_id:
request_id_message = f" (Amzn Trace ID: {request_id})"
if request_id and request_id.lower() not in final_error_message.lower():
if "\n" in final_error_message:
newline_index = final_error_message.index("\n")
final_error_message = (
final_error_message[:newline_index] + request_id_message + final_error_message[newline_index:]
)
else:
final_error_message += request_id_message
# Return
return error_type(final_error_message.strip(), response=response, server_message=server_message or None)

View File

@ -0,0 +1,110 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Git LFS related utilities"""
import io
import os
from contextlib import AbstractContextManager
from typing import BinaryIO
class SliceFileObj(AbstractContextManager):
"""
Utility context manager to read a *slice* of a seekable file-like object as a seekable, file-like object.
This is NOT thread safe
Inspired by stackoverflow.com/a/29838711/593036
Credits to @julien-c
Args:
fileobj (`BinaryIO`):
A file-like object to slice. MUST implement `tell()` and `seek()` (and `read()` of course).
`fileobj` will be reset to its original position when exiting the context manager.
seek_from (`int`):
The start of the slice (offset from position 0 in bytes).
read_limit (`int`):
The maximum number of bytes to read from the slice.
Attributes:
previous_position (`int`):
The previous position
Examples:
Reading 200 bytes with an offset of 128 bytes from a file (ie bytes 128 to 327):
```python
>>> with open("path/to/file", "rb") as file:
... with SliceFileObj(file, seek_from=128, read_limit=200) as fslice:
... fslice.read(...)
```
Reading a file in chunks of 512 bytes
```python
    >>> import os
    >>> from math import ceil
    >>> chunk_size = 512
    >>> file_size = os.path.getsize("path/to/file")
>>> with open("path/to/file", "rb") as file:
... for chunk_idx in range(ceil(file_size / chunk_size)):
... with SliceFileObj(file, seek_from=chunk_idx * chunk_size, read_limit=chunk_size) as fslice:
... chunk = fslice.read(...)
```
"""
def __init__(self, fileobj: BinaryIO, seek_from: int, read_limit: int):
self.fileobj = fileobj
self.seek_from = seek_from
self.read_limit = read_limit
def __enter__(self):
self._previous_position = self.fileobj.tell()
end_of_stream = self.fileobj.seek(0, os.SEEK_END)
self._len = min(self.read_limit, end_of_stream - self.seek_from)
# ^^ The actual number of bytes that can be read from the slice
self.fileobj.seek(self.seek_from, io.SEEK_SET)
return self
def __exit__(self, exc_type, exc_value, traceback):
self.fileobj.seek(self._previous_position, io.SEEK_SET)
def read(self, n: int = -1):
pos = self.tell()
if pos >= self._len:
return b""
remaining_amount = self._len - pos
data = self.fileobj.read(remaining_amount if n < 0 else min(n, remaining_amount))
return data
def tell(self) -> int:
return self.fileobj.tell() - self.seek_from
def seek(self, offset: int, whence: int = os.SEEK_SET) -> int:
start = self.seek_from
end = start + self._len
if whence in (os.SEEK_SET, os.SEEK_END):
offset = start + offset if whence == os.SEEK_SET else end + offset
offset = max(start, min(offset, end))
whence = os.SEEK_SET
elif whence == os.SEEK_CUR:
cur_pos = self.fileobj.tell()
offset = max(start - cur_pos, min(offset, end - cur_pos))
else:
raise ValueError(f"whence value {whence} is not supported")
return self.fileobj.seek(offset, whence) - self.seek_from
def __iter__(self):
yield self.read(n=4 * 1024 * 1024)
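# Worked example with an in-memory buffer; the bytes are illustrative.
#
#   import io
#
#   with io.BytesIO(b"0123456789") as buffer:
#       with SliceFileObj(buffer, seek_from=2, read_limit=4) as fslice:
#           fslice.read()  # -> b"2345"
#   # The outer buffer position is restored when the context manager exits.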

View File

@ -0,0 +1,52 @@
# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities to handle pagination on Huggingface Hub."""
from typing import Dict, Iterable, Optional
import requests
from . import get_session, hf_raise_for_status, logging
logger = logging.get_logger(__name__)
def paginate(path: str, params: Dict, headers: Dict) -> Iterable:
"""Fetch a list of models/datasets/spaces and paginate through results.
This is using the same "Link" header format as GitHub.
See:
- https://requests.readthedocs.io/en/latest/api/#requests.Response.links
- https://docs.github.com/en/rest/guides/traversing-with-pagination#link-header
"""
session = get_session()
r = session.get(path, params=params, headers=headers)
hf_raise_for_status(r)
yield from r.json()
# Follow pages
# Next link already contains query params
next_page = _get_next_page(r)
while next_page is not None:
logger.debug(f"Pagination detected. Requesting next page: {next_page}")
r = session.get(next_page, headers=headers)
hf_raise_for_status(r)
yield from r.json()
next_page = _get_next_page(r)
def _get_next_page(response: requests.Response) -> Optional[str]:
return response.links.get("next", {}).get("url")
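# Usage sketch; the URL and params are illustrative. `paginate` yields items across pages by
# following the `Link: ...; rel="next"` headers returned by the Hub.
#
#   from huggingface_hub.utils import build_hf_headers
#
#   for model in paginate(
#       "https://huggingface.co/api/models",
#       params={"author": "google"},
#       headers=build_hf_headers(),
#   ):
#       ...  # each `model` is one JSON item from the paginated listing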

View File

@ -0,0 +1,141 @@
# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities to handle paths in Huggingface Hub."""
from fnmatch import fnmatch
from pathlib import Path
from typing import Callable, Generator, Iterable, List, Optional, TypeVar, Union
T = TypeVar("T")
# Always ignore `.git` and `.cache/huggingface` folders in commits
DEFAULT_IGNORE_PATTERNS = [
".git",
".git/*",
"*/.git",
"**/.git/**",
".cache/huggingface",
".cache/huggingface/*",
"*/.cache/huggingface",
"**/.cache/huggingface/**",
]
# Forbidden to commit these folders
FORBIDDEN_FOLDERS = [".git", ".cache"]
def filter_repo_objects(
items: Iterable[T],
*,
allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None,
key: Optional[Callable[[T], str]] = None,
) -> Generator[T, None, None]:
"""Filter repo objects based on an allowlist and a denylist.
Input must be a list of paths (`str` or `Path`) or a list of arbitrary objects.
In the later case, `key` must be provided and specifies a function of one argument
that is used to extract a path from each element in iterable.
Patterns are Unix shell-style wildcards which are NOT regular expressions. See
https://docs.python.org/3/library/fnmatch.html for more details.
Args:
items (`Iterable`):
List of items to filter.
allow_patterns (`str` or `List[str]`, *optional*):
Patterns constituting the allowlist. If provided, item paths must match at
least one pattern from the allowlist.
ignore_patterns (`str` or `List[str]`, *optional*):
Patterns constituting the denylist. If provided, item paths must not match
any patterns from the denylist.
key (`Callable[[T], str]`, *optional*):
Single-argument function to extract a path from each item. If not provided,
the `items` must already be `str` or `Path`.
Returns:
Filtered list of objects, as a generator.
Raises:
:class:`ValueError`:
If `key` is not provided and items are not `str` or `Path`.
Example usage with paths:
```python
>>> # Filter only PDFs that are not hidden.
>>> list(filter_repo_objects(
... ["aaa.PDF", "bbb.jpg", ".ccc.pdf", ".ddd.png"],
... allow_patterns=["*.pdf"],
... ignore_patterns=[".*"],
... ))
["aaa.pdf"]
```
Example usage with objects:
```python
>>> list(filter_repo_objects(
... [
    ...         CommitOperationAdd(path_or_fileobj="/tmp/aaa.pdf", path_in_repo="aaa.pdf"),
    ...         CommitOperationAdd(path_or_fileobj="/tmp/bbb.jpg", path_in_repo="bbb.jpg"),
    ...         CommitOperationAdd(path_or_fileobj="/tmp/.ccc.pdf", path_in_repo=".ccc.pdf"),
    ...         CommitOperationAdd(path_or_fileobj="/tmp/.ddd.png", path_in_repo=".ddd.png"),
... ],
... allow_patterns=["*.pdf"],
... ignore_patterns=[".*"],
    ...     key=lambda x: x.path_in_repo
... ))
[CommitOperationAdd(path_or_fileobj="/tmp/aaa.pdf", path_in_repo="aaa.pdf")]
```
"""
if isinstance(allow_patterns, str):
allow_patterns = [allow_patterns]
if isinstance(ignore_patterns, str):
ignore_patterns = [ignore_patterns]
if allow_patterns is not None:
allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns]
if ignore_patterns is not None:
ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
if key is None:
def _identity(item: T) -> str:
if isinstance(item, str):
return item
if isinstance(item, Path):
return str(item)
raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
key = _identity # Items must be `str` or `Path`, otherwise raise ValueError
for item in items:
path = key(item)
# Skip if there's an allowlist and path doesn't match any
if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns):
continue
# Skip if there's a denylist and path matches any
if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns):
continue
yield item
def _add_wildcard_to_directories(pattern: str) -> str:
if pattern[-1] == "/":
return pattern + "*"
return pattern
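# Worked example; paths and patterns are illustrative. Directory-style patterns ending with
# "/" are expanded with a trailing wildcard before matching.
#
#   list(filter_repo_objects(
#       ["logs/run1.txt", "weights/model.bin", "README.md"],
#       ignore_patterns=["logs/"],  # internally rewritten to "logs/*"
#   ))
#   # -> ["weights/model.bin", "README.md"]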

View File

@ -0,0 +1,379 @@
# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Check presence of installed packages at runtime."""
import importlib.metadata
import os
import platform
import sys
import warnings
from typing import Any, Dict
from .. import __version__, constants
_PY_VERSION: str = sys.version.split()[0].rstrip("+")
_package_versions = {}
_CANDIDATES = {
"aiohttp": {"aiohttp"},
"fastai": {"fastai"},
"fastapi": {"fastapi"},
"fastcore": {"fastcore"},
"gradio": {"gradio"},
"graphviz": {"graphviz"},
"hf_transfer": {"hf_transfer"},
"jinja": {"Jinja2"},
"keras": {"keras"},
"numpy": {"numpy"},
"pillow": {"Pillow"},
"pydantic": {"pydantic"},
"pydot": {"pydot"},
"safetensors": {"safetensors"},
"tensorboard": {"tensorboardX"},
"tensorflow": (
"tensorflow",
"tensorflow-cpu",
"tensorflow-gpu",
"tf-nightly",
"tf-nightly-cpu",
"tf-nightly-gpu",
"intel-tensorflow",
"intel-tensorflow-avx512",
"tensorflow-rocm",
"tensorflow-macos",
),
"torch": {"torch"},
}
# Check once at runtime
for candidate_name, package_names in _CANDIDATES.items():
_package_versions[candidate_name] = "N/A"
for name in package_names:
try:
_package_versions[candidate_name] = importlib.metadata.version(name)
break
except importlib.metadata.PackageNotFoundError:
pass
def _get_version(package_name: str) -> str:
return _package_versions.get(package_name, "N/A")
def is_package_available(package_name: str) -> bool:
return _get_version(package_name) != "N/A"
# Python
def get_python_version() -> str:
return _PY_VERSION
# Huggingface Hub
def get_hf_hub_version() -> str:
return __version__
# aiohttp
def is_aiohttp_available() -> bool:
return is_package_available("aiohttp")
def get_aiohttp_version() -> str:
return _get_version("aiohttp")
# FastAI
def is_fastai_available() -> bool:
return is_package_available("fastai")
def get_fastai_version() -> str:
return _get_version("fastai")
# FastAPI
def is_fastapi_available() -> bool:
return is_package_available("fastapi")
def get_fastapi_version() -> str:
return _get_version("fastapi")
# Fastcore
def is_fastcore_available() -> bool:
return is_package_available("fastcore")
def get_fastcore_version() -> str:
return _get_version("fastcore")
# Gradio
def is_gradio_available() -> bool:
return is_package_available("gradio")
def get_gradio_version() -> str:
return _get_version("gradio")
# Graphviz
def is_graphviz_available() -> bool:
return is_package_available("graphviz")
def get_graphviz_version() -> str:
return _get_version("graphviz")
# hf_transfer
def is_hf_transfer_available() -> bool:
return is_package_available("hf_transfer")
def get_hf_transfer_version() -> str:
return _get_version("hf_transfer")
# keras
def is_keras_available() -> bool:
return is_package_available("keras")
def get_keras_version() -> str:
return _get_version("keras")
# Numpy
def is_numpy_available() -> bool:
return is_package_available("numpy")
def get_numpy_version() -> str:
return _get_version("numpy")
# Jinja
def is_jinja_available() -> bool:
return is_package_available("jinja")
def get_jinja_version() -> str:
return _get_version("jinja")
# Pillow
def is_pillow_available() -> bool:
return is_package_available("pillow")
def get_pillow_version() -> str:
return _get_version("pillow")
# Pydantic
def is_pydantic_available() -> bool:
if not is_package_available("pydantic"):
return False
# For Pydantic, we add an extra check to test whether it is correctly installed or not. If both pydantic 2.x and
# typing_extensions<=4.5.0 are installed, then pydantic will fail at import time. This should not happen when
# it is installed with `pip install huggingface_hub[inference]` but it can happen when it is installed manually
# by the user in an environment that we don't control.
#
# Usually we won't need to do this kind of check on optional dependencies. However, pydantic is a special case
# as it is automatically imported when doing `from huggingface_hub import ...` even if the user doesn't use it.
#
# See https://github.com/huggingface/huggingface_hub/pull/1829 for more details.
try:
from pydantic import validator # noqa: F401
    except ImportError as e:
        # Example: "ImportError: cannot import name 'TypeAliasType' from 'typing_extensions'"
        warnings.warn(
            "Pydantic is installed but cannot be imported. Please check your installation. `huggingface_hub` will "
            f"default to not using Pydantic. Error message: '{e}'"
)
return False
return True
def get_pydantic_version() -> str:
return _get_version("pydantic")
# Pydot
def is_pydot_available() -> bool:
return is_package_available("pydot")
def get_pydot_version() -> str:
return _get_version("pydot")
# Tensorboard
def is_tensorboard_available() -> bool:
return is_package_available("tensorboard")
def get_tensorboard_version() -> str:
return _get_version("tensorboard")
# Tensorflow
def is_tf_available() -> bool:
return is_package_available("tensorflow")
def get_tf_version() -> str:
return _get_version("tensorflow")
# Torch
def is_torch_available() -> bool:
return is_package_available("torch")
def get_torch_version() -> str:
return _get_version("torch")
# Safetensors
def is_safetensors_available() -> bool:
return is_package_available("safetensors")
# Shell-related helpers
try:
# Set to `True` if script is running in a Google Colab notebook.
# If running in Google Colab, git credential store is set globally which makes the
# warning disappear. See https://github.com/huggingface/huggingface_hub/issues/1043
#
# Taken from https://stackoverflow.com/a/63519730.
_is_google_colab = "google.colab" in str(get_ipython()) # type: ignore # noqa: F821
except NameError:
_is_google_colab = False
def is_notebook() -> bool:
"""Return `True` if code is executed in a notebook (Jupyter, Colab, QTconsole).
Taken from https://stackoverflow.com/a/39662359.
Adapted to make it work with Google colab as well.
"""
try:
shell_class = get_ipython().__class__ # type: ignore # noqa: F821
for parent_class in shell_class.__mro__: # e.g. "is subclass of"
if parent_class.__name__ == "ZMQInteractiveShell":
return True # Jupyter notebook, Google colab or qtconsole
return False
except NameError:
return False # Probably standard Python interpreter
def is_google_colab() -> bool:
"""Return `True` if code is executed in a Google colab.
Taken from https://stackoverflow.com/a/63519730.
"""
return _is_google_colab
def is_colab_enterprise() -> bool:
"""Return `True` if code is executed in a Google Colab Enterprise environment."""
return os.environ.get("VERTEX_PRODUCT") == "COLAB_ENTERPRISE"
def dump_environment_info() -> Dict[str, Any]:
"""Dump information about the machine to help debugging issues.
Similar helper exist in:
- `datasets` (https://github.com/huggingface/datasets/blob/main/src/datasets/commands/env.py)
- `diffusers` (https://github.com/huggingface/diffusers/blob/main/src/diffusers/commands/env.py)
- `transformers` (https://github.com/huggingface/transformers/blob/main/src/transformers/commands/env.py)
"""
from huggingface_hub import get_token, whoami
from huggingface_hub.utils import list_credential_helpers
token = get_token()
# Generic machine info
info: Dict[str, Any] = {
"huggingface_hub version": get_hf_hub_version(),
"Platform": platform.platform(),
"Python version": get_python_version(),
}
# Interpreter info
try:
shell_class = get_ipython().__class__ # type: ignore # noqa: F821
info["Running in iPython ?"] = "Yes"
info["iPython shell"] = shell_class.__name__
except NameError:
info["Running in iPython ?"] = "No"
info["Running in notebook ?"] = "Yes" if is_notebook() else "No"
info["Running in Google Colab ?"] = "Yes" if is_google_colab() else "No"
info["Running in Google Colab Enterprise ?"] = "Yes" if is_colab_enterprise() else "No"
# Login info
info["Token path ?"] = constants.HF_TOKEN_PATH
info["Has saved token ?"] = token is not None
if token is not None:
try:
info["Who am I ?"] = whoami()["name"]
except Exception:
pass
try:
info["Configured git credential helpers"] = ", ".join(list_credential_helpers())
except Exception:
pass
# Installed dependencies
info["FastAI"] = get_fastai_version()
info["Tensorflow"] = get_tf_version()
info["Torch"] = get_torch_version()
info["Jinja2"] = get_jinja_version()
info["Graphviz"] = get_graphviz_version()
info["keras"] = get_keras_version()
info["Pydot"] = get_pydot_version()
info["Pillow"] = get_pillow_version()
info["hf_transfer"] = get_hf_transfer_version()
info["gradio"] = get_gradio_version()
info["tensorboard"] = get_tensorboard_version()
info["numpy"] = get_numpy_version()
info["pydantic"] = get_pydantic_version()
info["aiohttp"] = get_aiohttp_version()
# Environment variables
info["ENDPOINT"] = constants.ENDPOINT
info["HF_HUB_CACHE"] = constants.HF_HUB_CACHE
info["HF_ASSETS_CACHE"] = constants.HF_ASSETS_CACHE
info["HF_TOKEN_PATH"] = constants.HF_TOKEN_PATH
info["HF_STORED_TOKENS_PATH"] = constants.HF_STORED_TOKENS_PATH
info["HF_HUB_OFFLINE"] = constants.HF_HUB_OFFLINE
info["HF_HUB_DISABLE_TELEMETRY"] = constants.HF_HUB_DISABLE_TELEMETRY
info["HF_HUB_DISABLE_PROGRESS_BARS"] = constants.HF_HUB_DISABLE_PROGRESS_BARS
info["HF_HUB_DISABLE_SYMLINKS_WARNING"] = constants.HF_HUB_DISABLE_SYMLINKS_WARNING
info["HF_HUB_DISABLE_EXPERIMENTAL_WARNING"] = constants.HF_HUB_DISABLE_EXPERIMENTAL_WARNING
info["HF_HUB_DISABLE_IMPLICIT_TOKEN"] = constants.HF_HUB_DISABLE_IMPLICIT_TOKEN
info["HF_HUB_ENABLE_HF_TRANSFER"] = constants.HF_HUB_ENABLE_HF_TRANSFER
info["HF_HUB_ETAG_TIMEOUT"] = constants.HF_HUB_ETAG_TIMEOUT
info["HF_HUB_DOWNLOAD_TIMEOUT"] = constants.HF_HUB_DOWNLOAD_TIMEOUT
print("\nCopy-and-paste the text below in your GitHub issue.\n")
print("\n".join([f"- {prop}: {val}" for prop, val in info.items()]) + "\n")
return info
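# Usage sketch; assumes `dump_environment_info` is re-exported at the package root and that the
# `huggingface-cli` entry point is installed alongside the package.
#
#   from huggingface_hub import dump_environment_info
#   dump_environment_info()  # prints a "- key: value" report ready to paste into a GitHub issue
#
#   # or, from a shell:
#   #   huggingface-cli env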

View File

@ -0,0 +1,111 @@
import functools
import operator
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Literal, Optional, Tuple
FILENAME_T = str
TENSOR_NAME_T = str
DTYPE_T = Literal["F64", "F32", "F16", "BF16", "I64", "I32", "I16", "I8", "U8", "BOOL"]
@dataclass
class TensorInfo:
"""Information about a tensor.
For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.
Attributes:
dtype (`str`):
The data type of the tensor ("F64", "F32", "F16", "BF16", "I64", "I32", "I16", "I8", "U8", "BOOL").
shape (`List[int]`):
The shape of the tensor.
data_offsets (`Tuple[int, int]`):
The offsets of the data in the file as a tuple `[BEGIN, END]`.
parameter_count (`int`):
The number of parameters in the tensor.
"""
dtype: DTYPE_T
shape: List[int]
data_offsets: Tuple[int, int]
parameter_count: int = field(init=False)
def __post_init__(self) -> None:
# Taken from https://stackoverflow.com/a/13840436
try:
self.parameter_count = functools.reduce(operator.mul, self.shape)
except TypeError:
self.parameter_count = 1 # scalar value has no shape
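# Worked example; the shape and offsets are illustrative. `parameter_count` is derived from the
# shape, so a 1024x768 F32 tensor reports 786432 parameters.
#
#   info = TensorInfo(dtype="F32", shape=[1024, 768], data_offsets=(0, 3_145_728))
#   info.parameter_count  # -> 786432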
@dataclass
class SafetensorsFileMetadata:
"""Metadata for a Safetensors file hosted on the Hub.
This class is returned by [`parse_safetensors_file_metadata`].
For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.
Attributes:
metadata (`Dict`):
The metadata contained in the file.
tensors (`Dict[str, TensorInfo]`):
A map of all tensors. Keys are tensor names and values are information about the corresponding tensor, as a
[`TensorInfo`] object.
parameter_count (`Dict[str, int]`):
A map of the number of parameters per data type. Keys are data types and values are the number of parameters
of that data type.
"""
metadata: Dict[str, str]
tensors: Dict[TENSOR_NAME_T, TensorInfo]
parameter_count: Dict[DTYPE_T, int] = field(init=False)
def __post_init__(self) -> None:
parameter_count: Dict[DTYPE_T, int] = defaultdict(int)
for tensor in self.tensors.values():
parameter_count[tensor.dtype] += tensor.parameter_count
self.parameter_count = dict(parameter_count)
@dataclass
class SafetensorsRepoMetadata:
"""Metadata for a Safetensors repo.
A repo is considered to be a Safetensors repo if it contains either a 'model.safetensors' weight file (non-shared
model) or a 'model.safetensors.index.json' index file (sharded model) at its root.
This class is returned by [`get_safetensors_metadata`].
For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.
Attributes:
metadata (`Dict`, *optional*):
The metadata contained in the 'model.safetensors.index.json' file, if it exists. Only populated for sharded
models.
sharded (`bool`):
Whether the repo contains a sharded model or not.
weight_map (`Dict[str, str]`):
A map of all weights. Keys are tensor names and values are filenames of the files containing the tensors.
files_metadata (`Dict[str, SafetensorsFileMetadata]`):
A map of all files metadata. Keys are filenames and values are the metadata of the corresponding file, as
a [`SafetensorsFileMetadata`] object.
parameter_count (`Dict[str, int]`):
A map of the number of parameters per data type. Keys are data types and values are the number of parameters
of that data type.
"""
metadata: Optional[Dict]
sharded: bool
weight_map: Dict[TENSOR_NAME_T, FILENAME_T] # tensor name -> filename
files_metadata: Dict[FILENAME_T, SafetensorsFileMetadata] # filename -> metadata
parameter_count: Dict[DTYPE_T, int] = field(init=False)
def __post_init__(self) -> None:
parameter_count: Dict[DTYPE_T, int] = defaultdict(int)
for file_metadata in self.files_metadata.values():
for dtype, nb_parameters_ in file_metadata.parameter_count.items():
parameter_count[dtype] += nb_parameters_
self.parameter_count = dict(parameter_count)
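# Hedged usage sketch (illustration only, not part of the original module):
# `parameter_count` fields are derived automatically in `__post_init__`, both
# per tensor and per dtype. The shape and offsets below are made up.
if __name__ == "__main__":
    tensor = TensorInfo(dtype="F32", shape=[2, 3, 4], data_offsets=(0, 96))
    file_metadata = SafetensorsFileMetadata(metadata={}, tensors={"weight": tensor})
    assert tensor.parameter_count == 24  # 2 * 3 * 4
    assert file_metadata.parameter_count == {"F32": 24}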

View File

@ -0,0 +1,142 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
"""Contains utilities to easily handle subprocesses in `huggingface_hub`."""
import os
import subprocess
import sys
from contextlib import contextmanager
from io import StringIO
from pathlib import Path
from typing import IO, Generator, List, Optional, Tuple, Union
from .logging import get_logger
logger = get_logger(__name__)
@contextmanager
def capture_output() -> Generator[StringIO, None, None]:
"""Capture output that is printed to terminal.
Taken from https://stackoverflow.com/a/34738440
Example:
```py
>>> with capture_output() as output:
... print("hello world")
>>> assert output.getvalue() == "hello world\n"
```
"""
output = StringIO()
previous_output = sys.stdout
sys.stdout = output
yield output
sys.stdout = previous_output
def run_subprocess(
command: Union[str, List[str]],
folder: Optional[Union[str, Path]] = None,
check=True,
**kwargs,
) -> subprocess.CompletedProcess:
"""
Method to run subprocesses. Calling this will capture `stderr` and `stdout`;
call `subprocess.run` manually if you would like them not to be captured.
Args:
command (`str` or `List[str]`):
The command to execute as a string or list of strings.
folder (`str`, *optional*):
The folder in which to run the command. Defaults to current working
directory (from `os.getcwd()`).
check (`bool`, *optional*, defaults to `True`):
Setting `check` to `True` will raise a `subprocess.CalledProcessError`
when the subprocess has a non-zero exit code.
kwargs (`Dict[str]`):
Keyword arguments passed to the underlying `subprocess.run` call.
Returns:
`subprocess.CompletedProcess`: The completed process.
"""
if isinstance(command, str):
command = command.split()
if isinstance(folder, Path):
folder = str(folder)
return subprocess.run(
command,
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
check=check,
encoding="utf-8",
errors="replace", # if not utf-8, replace char by <20>
cwd=folder or os.getcwd(),
**kwargs,
)
@contextmanager
def run_interactive_subprocess(
command: Union[str, List[str]],
folder: Optional[Union[str, Path]] = None,
**kwargs,
) -> Generator[Tuple[IO[str], IO[str]], None, None]:
"""Run a subprocess in an interactive mode in a context manager.
Args:
command (`str` or `List[str]`):
The command to execute as a string or list of strings.
folder (`str`, *optional*):
The folder in which to run the command. Defaults to current working
directory (from `os.getcwd()`).
kwargs (`Dict[str]`):
Keyword arguments passed to the underlying `subprocess.run` call.
Returns:
`Tuple[IO[str], IO[str]]`: A tuple with `stdin` and `stdout` to interact
with the process (input and output are utf-8 encoded).
Example:
```python
with run_interactive_subprocess("git credential-store get") as (stdin, stdout):
    # Write to stdin (streams are opened in text mode, utf-8 encoded)
    stdin.write("url=hf.co\nusername=obama\n")
    stdin.flush()
    # Read from stdout
    output = stdout.read()
```
"""
if isinstance(command, str):
command = command.split()
with subprocess.Popen(
command,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
encoding="utf-8",
errors="replace", # if not utf-8, replace char by <20>
cwd=folder or os.getcwd(),
**kwargs,
) as process:
assert process.stdin is not None, "subprocess is opened as subprocess.PIPE"
assert process.stdout is not None, "subprocess is opened as subprocess.PIPE"
yield process.stdin, process.stdout
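# Hedged usage sketch (illustration only, not part of the original module): run
# the current Python interpreter as a child process; stdout is captured and
# utf-8 decoded by `run_subprocess`.
if __name__ == "__main__":
    result = run_subprocess([sys.executable, "-c", "print('hello')"])
    assert result.stdout.strip() == "hello"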

View File

@ -0,0 +1,126 @@
from queue import Queue
from threading import Lock, Thread
from typing import Dict, Optional, Union
from urllib.parse import quote
from .. import constants, logging
from . import build_hf_headers, get_session, hf_raise_for_status
logger = logging.get_logger(__name__)
# Telemetry is sent by a separate thread to avoid blocking the main thread.
# A daemon thread is started once and consumes tasks from the _TELEMETRY_QUEUE.
# If the thread stops for some reason (shouldn't happen), a new one is started.
_TELEMETRY_THREAD: Optional[Thread] = None
_TELEMETRY_THREAD_LOCK = Lock() # Lock to avoid starting multiple threads in parallel
_TELEMETRY_QUEUE: Queue = Queue()
def send_telemetry(
topic: str,
*,
library_name: Optional[str] = None,
library_version: Optional[str] = None,
user_agent: Union[Dict, str, None] = None,
) -> None:
"""
Sends telemetry that helps tracking usage of different HF libraries.
This usage data helps us debug issues and prioritize new features. However, we understand that not everyone wants
to share additional information, and we respect your privacy. You can disable telemetry collection by setting the
`HF_HUB_DISABLE_TELEMETRY=1` as environment variable. Telemetry is also disabled in offline mode (i.e. when setting
`HF_HUB_OFFLINE=1`).
Telemetry collection is run in a separate thread to minimize impact for the user.
Args:
topic (`str`):
Name of the topic that is monitored. The topic is directly used to build the URL. If you want to monitor
subtopics, just use "/" separation. Examples: "gradio", "transformers/examples",...
library_name (`str`, *optional*):
The name of the library that is making the HTTP request. Will be added to the user-agent header.
library_version (`str`, *optional*):
The version of the library that is making the HTTP request. Will be added to the user-agent header.
user_agent (`str`, `dict`, *optional*):
The user agent info in the form of a dictionary or a single string. It will be completed with information about the installed packages.
Example:
```py
>>> from huggingface_hub.utils import send_telemetry
# Send telemetry without library information
>>> send_telemetry("ping")
# Send telemetry to subtopic with library information
>>> send_telemetry("gradio/local_link", library_name="gradio", library_version="3.22.1")
# Send telemetry with additional data
>>> send_telemetry(
... topic="examples",
... library_name="transformers",
... library_version="4.26.0",
... user_agent={"pipeline": "text_classification", "framework": "flax"},
... )
```
"""
if constants.HF_HUB_OFFLINE or constants.HF_HUB_DISABLE_TELEMETRY:
return
_start_telemetry_thread() # starts thread only if doesn't exist yet
_TELEMETRY_QUEUE.put(
{"topic": topic, "library_name": library_name, "library_version": library_version, "user_agent": user_agent}
)
def _start_telemetry_thread():
"""Start a daemon thread to consume tasks from the telemetry queue.
If the thread is interrupted, start a new one.
"""
with _TELEMETRY_THREAD_LOCK:  # avoid starting multiple threads if called concurrently
global _TELEMETRY_THREAD
if _TELEMETRY_THREAD is None or not _TELEMETRY_THREAD.is_alive():
_TELEMETRY_THREAD = Thread(target=_telemetry_worker, daemon=True)
_TELEMETRY_THREAD.start()
def _telemetry_worker():
"""Wait for a task and consume it."""
while True:
kwargs = _TELEMETRY_QUEUE.get()
_send_telemetry_in_thread(**kwargs)
_TELEMETRY_QUEUE.task_done()
def _send_telemetry_in_thread(
topic: str,
*,
library_name: Optional[str] = None,
library_version: Optional[str] = None,
user_agent: Union[Dict, str, None] = None,
) -> None:
"""Contains the actual data sending data to the Hub.
This function is called directly in gradio's analytics because
it is not possible to send telemetry from a daemon thread.
See here: https://github.com/gradio-app/gradio/pull/8180
Please do not rename or remove this function.
"""
path = "/".join(quote(part) for part in topic.split("/") if len(part) > 0)
try:
r = get_session().head(
f"{constants.ENDPOINT}/api/telemetry/{path}",
headers=build_hf_headers(
token=False, # no need to send a token for telemetry
library_name=library_name,
library_version=library_version,
user_agent=user_agent,
),
)
hf_raise_for_status(r)
except Exception as e:
# We don't want to error in case of connection errors of any kind.
logger.debug(f"Error while sending telemetry: {e}")

View File

@ -0,0 +1,75 @@
# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Handle typing imports based on system compatibility."""
import sys
from typing import Any, Callable, List, Literal, Type, TypeVar, Union, get_args, get_origin
UNION_TYPES: List[Any] = [Union]
if sys.version_info >= (3, 10):
from types import UnionType
UNION_TYPES += [UnionType]
HTTP_METHOD_T = Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"]
# type hint meaning "function signature not changed by decorator"
CallableT = TypeVar("CallableT", bound=Callable)
_JSON_SERIALIZABLE_TYPES = (int, float, str, bool, type(None))
def is_jsonable(obj: Any) -> bool:
"""Check if an object is JSON serializable.
This is a weak check, as it does not check for the actual JSON serialization, but only for the types of the object.
It works correctly for basic use cases but does not guarantee an exhaustive check.
Object is considered to be recursively json serializable if:
- it is an instance of int, float, str, bool, or NoneType
- it is a list or tuple and all its items are json serializable
- it is a dict and all its keys are strings and all its values are json serializable
"""
try:
if isinstance(obj, _JSON_SERIALIZABLE_TYPES):
return True
if isinstance(obj, (list, tuple)):
return all(is_jsonable(item) for item in obj)
if isinstance(obj, dict):
return all(isinstance(key, str) and is_jsonable(value) for key, value in obj.items())
if hasattr(obj, "__json__"):
return True
return False
except RecursionError:
return False
def is_simple_optional_type(type_: Type) -> bool:
"""Check if a type is optional, i.e. Optional[Type] or Union[Type, None] or Type | None, where Type is a non-composite type."""
if get_origin(type_) in UNION_TYPES:
union_args = get_args(type_)
if len(union_args) == 2 and type(None) in union_args:
return True
return False
def unwrap_simple_optional_type(optional_type: Type) -> Type:
"""Unwraps a simple optional type, i.e. returns Type from Optional[Type]."""
for arg in get_args(optional_type):
if arg is not type(None):
return arg
raise ValueError(f"'{optional_type}' is not an optional type")

View File

@ -0,0 +1,226 @@
# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities to validate argument values in `huggingface_hub`."""
import inspect
import re
import warnings
from functools import wraps
from itertools import chain
from typing import Any, Dict
from huggingface_hub.errors import HFValidationError
from ._typing import CallableT
REPO_ID_REGEX = re.compile(
r"""
^
(\b[\w\-.]+\b/)? # optional namespace (username or organization)
\b # starts with a word boundary
[\w\-.]{1,96} # repo_name: alphanumeric + . _ -
\b # ends with a word boundary
$
""",
flags=re.VERBOSE,
)
def validate_hf_hub_args(fn: CallableT) -> CallableT:
"""Validate values received as argument for any public method of `huggingface_hub`.
The goal of this decorator is to harmonize validation of arguments reused
everywhere. By default, all defined validators are tested.
Validators:
- [`~utils.validate_repo_id`]: `repo_id` must be `"repo_name"`
or `"namespace/repo_name"`. Namespace is a username or an organization.
- [`~utils.smoothly_deprecate_use_auth_token`]: Use `token` instead of
`use_auth_token` (only if `use_auth_token` is not expected by the decorated
function - in practice, always the case in `huggingface_hub`).
Example:
```py
>>> from huggingface_hub.utils import validate_hf_hub_args
>>> @validate_hf_hub_args
... def my_cool_method(repo_id: str):
... print(repo_id)
>>> my_cool_method(repo_id="valid_repo_id")
valid_repo_id
>>> my_cool_method("other..repo..id")
huggingface_hub.utils._validators.HFValidationError: Cannot have -- or .. in repo_id: 'other..repo..id'.
>>> my_cool_method(repo_id="other..repo..id")
huggingface_hub.utils._validators.HFValidationError: Cannot have -- or .. in repo_id: 'other..repo..id'.
>>> @validate_hf_hub_args
... def my_cool_auth_method(token: str):
... print(token)
>>> my_cool_auth_method(token="a token")
"a token"
>>> my_cool_auth_method(use_auth_token="a use_auth_token")
"a use_auth_token"
>>> my_cool_auth_method(token="a token", use_auth_token="a use_auth_token")
UserWarning: Both `token` and `use_auth_token` are passed (...)
"a token"
```
Raises:
[`~utils.HFValidationError`]:
If an input is not valid.
"""
# TODO: add an argument to opt-out validation for specific argument?
signature = inspect.signature(fn)
# Should the validator switch `use_auth_token` values to `token`? In practice, always
# True in `huggingface_hub`. Might not be the case in a downstream library.
check_use_auth_token = "use_auth_token" not in signature.parameters and "token" in signature.parameters
@wraps(fn)
def _inner_fn(*args, **kwargs):
has_token = False
for arg_name, arg_value in chain(
zip(signature.parameters, args), # Args values
kwargs.items(), # Kwargs values
):
if arg_name in ["repo_id", "from_id", "to_id"]:
validate_repo_id(arg_value)
elif arg_name == "token" and arg_value is not None:
has_token = True
if check_use_auth_token:
kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.__name__, has_token=has_token, kwargs=kwargs)
return fn(*args, **kwargs)
return _inner_fn # type: ignore
def validate_repo_id(repo_id: str) -> None:
"""Validate `repo_id` is valid.
This is not meant to replace the proper validation made on the Hub but rather to
avoid local inconsistencies whenever possible (example: passing `repo_type` in the
`repo_id` is forbidden).
Rules:
- Between 1 and 96 characters.
- Either "repo_name" or "namespace/repo_name"
- [a-zA-Z0-9] or "-", "_", "."
- "--" and ".." are forbidden
Valid: `"foo"`, `"foo/bar"`, `"123"`, `"Foo-BAR_foo.bar123"`
Not valid: `"datasets/foo/bar"`, `".repo_id"`, `"foo--bar"`, `"foo.git"`
Example:
```py
>>> from huggingface_hub.utils import validate_repo_id
>>> validate_repo_id(repo_id="valid_repo_id")
>>> validate_repo_id(repo_id="other..repo..id")
huggingface_hub.utils._validators.HFValidationError: Cannot have -- or .. in repo_id: 'other..repo..id'.
```
Discussed in https://github.com/huggingface/huggingface_hub/issues/1008.
In moon-landing (internal repository):
- https://github.com/huggingface/moon-landing/blob/main/server/lib/Names.ts#L27
- https://github.com/huggingface/moon-landing/blob/main/server/views/components/NewRepoForm/NewRepoForm.svelte#L138
"""
if not isinstance(repo_id, str):
# Typically, a Path is not a repo_id
raise HFValidationError(f"Repo id must be a string, not {type(repo_id)}: '{repo_id}'.")
if repo_id.count("/") > 1:
raise HFValidationError(
"Repo id must be in the form 'repo_name' or 'namespace/repo_name':"
f" '{repo_id}'. Use `repo_type` argument if needed."
)
if not REPO_ID_REGEX.match(repo_id):
raise HFValidationError(
"Repo id must use alphanumeric chars or '-', '_', '.', '--' and '..' are"
" forbidden, '-' and '.' cannot start or end the name, max length is 96:"
f" '{repo_id}'."
)
if "--" in repo_id or ".." in repo_id:
raise HFValidationError(f"Cannot have -- or .. in repo_id: '{repo_id}'.")
if repo_id.endswith(".git"):
raise HFValidationError(f"Repo_id cannot end by '.git': '{repo_id}'.")
def smoothly_deprecate_use_auth_token(fn_name: str, has_token: bool, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""Smoothly deprecate `use_auth_token` in the `huggingface_hub` codebase.
The long-term goal is to remove any mention of `use_auth_token` in the codebase in
favor of a unique and less verbose `token` argument. This will be done in a few steps:
0. Step 0: methods that require a read-access to the Hub use the `use_auth_token`
argument (`str`, `bool` or `None`). Methods requiring write-access have a `token`
argument (`str`, `None`). This implicit rule exists so that the token is not sent
when not necessary (`use_auth_token=False`), even if the user is logged in.
1. Step 1: we want to harmonize everything and use `token` everywhere (supporting
`token=False` for read-only methods). In order not to break existing code, if
`use_auth_token` is passed to a function, the `use_auth_token` value is passed
as `token` instead, without any warning.
a. Corner case: if both `use_auth_token` and `token` values are passed, a warning
is thrown and the `use_auth_token` value is ignored.
2. Step 2: Once it is released, we should push downstream libraries to switch from
`use_auth_token` to `token` as much as possible, but without throwing a warning
(e.g. manually create issues on the corresponding repos).
3. Step 3: After a transitional period (e.g. 6 months, until April 2023?), we update
`huggingface_hub` to throw a warning on `use_auth_token`. Hopefully, very few
users will be impacted as it would have already been fixed.
In addition, unit tests in `huggingface_hub` must be adapted to expect warnings
to be thrown (but still use `use_auth_token` as before).
4. Step 4: After a normal deprecation cycle (3 releases?), remove this validator.
`use_auth_token` will definitely not be supported.
In addition, we update unit tests in `huggingface_hub` to use `token` everywhere.
This has been discussed in:
- https://github.com/huggingface/huggingface_hub/issues/1094.
- https://github.com/huggingface/huggingface_hub/pull/928
- (related) https://github.com/huggingface/huggingface_hub/pull/1064
"""
new_kwargs = kwargs.copy() # do not mutate input !
use_auth_token = new_kwargs.pop("use_auth_token", None) # remove from kwargs
if use_auth_token is not None:
if has_token:
warnings.warn(
"Both `token` and `use_auth_token` are passed to"
f" `{fn_name}` with non-None values. `token` is now the"
" preferred argument to pass a User Access Token."
" `use_auth_token` value will be ignored."
)
else:
# `token` argument is not passed and a non-None value is passed in
# `use_auth_token` => use `use_auth_token` value as `token` kwarg.
new_kwargs["token"] = use_auth_token
return new_kwargs
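# Hedged usage sketch (illustration only, not part of the original module): the
# deprecated `use_auth_token` value is remapped to `token` when `token` itself
# was not passed. The token string below is a made-up placeholder.
if __name__ == "__main__":
    remapped = smoothly_deprecate_use_auth_token(
        fn_name="my_fn", has_token=False, kwargs={"use_auth_token": "hf_xxx"}
    )
    assert remapped == {"token": "hf_xxx"}
    validate_repo_id("namespace/repo_name")  # raises HFValidationError if invalid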

View File

@ -0,0 +1,66 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Helpful utility functions and classes for exploring API endpoints,
with the aim of providing a user-friendly interface.
"""
import math
import re
from typing import TYPE_CHECKING
from ..repocard_data import ModelCardData
if TYPE_CHECKING:
from ..hf_api import ModelInfo
def _is_emission_within_threshold(model_info: "ModelInfo", minimum_threshold: float, maximum_threshold: float) -> bool:
"""Checks if a model's emission is within a given threshold.
Args:
model_info (`ModelInfo`):
A model info object containing the model's emission information.
minimum_threshold (`float`):
A minimum carbon threshold to filter by, such as 1.
maximum_threshold (`float`):
A maximum carbon threshold to filter by, such as 10.
Returns:
`bool`: Whether the model's emission is within the given threshold.
"""
if minimum_threshold is None and maximum_threshold is None:
raise ValueError("Both `minimum_threshold` and `maximum_threshold` cannot both be `None`")
if minimum_threshold is None:
minimum_threshold = -1
if maximum_threshold is None:
maximum_threshold = math.inf
card_data = getattr(model_info, "card_data", None)
if card_data is None or not isinstance(card_data, (dict, ModelCardData)):
return False
# Get CO2 emission metadata
emission = card_data.get("co2_eq_emissions", None)
if isinstance(emission, dict):
emission = emission["emissions"]
if not emission:
return False
# Filter out if value is missing or out of range
matched = re.search(r"\d+\.\d+|\d+", str(emission))
if matched is None:
return False
emission_value = float(matched.group(0))
return minimum_threshold <= emission_value <= maximum_threshold
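# Hedged usage sketch (illustration only, not part of the original module): the
# regex above extracts the first numeric value from a free-form emission field.
if __name__ == "__main__":
    matched = re.search(r"\d+\.\d+|\d+", "12.5 g of CO2 eq")
    assert matched is not None and float(matched.group(0)) == 12.5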

View File

@ -0,0 +1,34 @@
# Taken from https://github.com/mlflow/mlflow/pull/10119
#
# DO NOT use this function for security purposes (e.g., password hashing).
#
# In Python >= 3.9, insecure hashing algorithms such as MD5 fail in FIPS-compliant
# environments unless `usedforsecurity=False` is explicitly passed.
#
# References:
# - https://github.com/mlflow/mlflow/issues/9905
# - https://github.com/mlflow/mlflow/pull/10119
# - https://docs.python.org/3/library/hashlib.html
# - https://github.com/huggingface/transformers/pull/27038
#
# Usage:
# ```python
# # Use
# from huggingface_hub.utils.insecure_hashlib import sha256
# # instead of
# from hashlib import sha256
#
# # Use
# from huggingface_hub.utils import insecure_hashlib
# # instead of
# import hashlib
# ```
import functools
import hashlib
import sys
_kwargs = {"usedforsecurity": False} if sys.version_info >= (3, 9) else {}
md5 = functools.partial(hashlib.md5, **_kwargs)
sha1 = functools.partial(hashlib.sha1, **_kwargs)
sha256 = functools.partial(hashlib.sha256, **_kwargs)
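# Hedged usage sketch (illustration only, not part of the original module): the
# partials behave exactly like the regular hashlib constructors, only with
# `usedforsecurity=False` when the Python version supports it.
if __name__ == "__main__":
    assert sha256(b"hello").hexdigest() == hashlib.sha256(b"hello").hexdigest()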

View File

@ -0,0 +1,182 @@
# coding=utf-8
# Copyright 2020 Optuna, Hugging Face
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Logging utilities."""
import logging
import os
from logging import (
CRITICAL, # NOQA
DEBUG, # NOQA
ERROR, # NOQA
FATAL, # NOQA
INFO, # NOQA
NOTSET, # NOQA
WARN, # NOQA
WARNING, # NOQA
)
from typing import Optional
log_levels = {
"debug": logging.DEBUG,
"info": logging.INFO,
"warning": logging.WARNING,
"error": logging.ERROR,
"critical": logging.CRITICAL,
}
_default_log_level = logging.WARNING
def _get_library_name() -> str:
return __name__.split(".")[0]
def _get_library_root_logger() -> logging.Logger:
return logging.getLogger(_get_library_name())
def _get_default_logging_level():
"""
If the `HF_HUB_VERBOSITY` env var is set to one of the valid choices, return that as the new default level.
If it is not, fall back to `_default_log_level`.
"""
env_level_str = os.getenv("HF_HUB_VERBOSITY", None)
if env_level_str:
if env_level_str in log_levels:
return log_levels[env_level_str]
else:
logging.getLogger().warning(
f"Unknown option HF_HUB_VERBOSITY={env_level_str}, has to be one of: { ', '.join(log_levels.keys()) }"
)
return _default_log_level
def _configure_library_root_logger() -> None:
library_root_logger = _get_library_root_logger()
library_root_logger.addHandler(logging.StreamHandler())
library_root_logger.setLevel(_get_default_logging_level())
def _reset_library_root_logger() -> None:
library_root_logger = _get_library_root_logger()
library_root_logger.setLevel(logging.NOTSET)
def get_logger(name: Optional[str] = None) -> logging.Logger:
"""
Returns a logger with the specified name. This function is not supposed
to be directly accessed by library users.
Args:
name (`str`, *optional*):
The name of the logger to get, usually the filename
Example:
```python
>>> from huggingface_hub.utils import logging
>>> logger = logging.get_logger(__file__)
>>> logging.set_verbosity_info()
```
"""
if name is None:
name = _get_library_name()
return logging.getLogger(name)
def get_verbosity() -> int:
"""Return the current level for the HuggingFace Hub's root logger.
Returns:
Logging level, e.g., `huggingface_hub.logging.DEBUG` and
`huggingface_hub.logging.INFO`.
<Tip>
HuggingFace Hub has the following logging levels:
- `huggingface_hub.logging.CRITICAL`, `huggingface_hub.logging.FATAL`
- `huggingface_hub.logging.ERROR`
- `huggingface_hub.logging.WARNING`, `huggingface_hub.logging.WARN`
- `huggingface_hub.logging.INFO`
- `huggingface_hub.logging.DEBUG`
</Tip>
"""
return _get_library_root_logger().getEffectiveLevel()
def set_verbosity(verbosity: int) -> None:
"""
Sets the level for the HuggingFace Hub's root logger.
Args:
verbosity (`int`):
Logging level, e.g., `huggingface_hub.logging.DEBUG` and
`huggingface_hub.logging.INFO`.
"""
_get_library_root_logger().setLevel(verbosity)
def set_verbosity_info():
"""
Sets the verbosity to `logging.INFO`.
"""
return set_verbosity(INFO)
def set_verbosity_warning():
"""
Sets the verbosity to `logging.WARNING`.
"""
return set_verbosity(WARNING)
def set_verbosity_debug():
"""
Sets the verbosity to `logging.DEBUG`.
"""
return set_verbosity(DEBUG)
def set_verbosity_error():
"""
Sets the verbosity to `logging.ERROR`.
"""
return set_verbosity(ERROR)
def disable_propagation() -> None:
"""
Disable propagation of the library log outputs. Note that log propagation is
disabled by default.
"""
_get_library_root_logger().propagate = False
def enable_propagation() -> None:
"""
Enable propagation of the library log outputs. Please disable the
HuggingFace Hub's default handler to prevent double logging if the root
logger has been configured.
"""
_get_library_root_logger().propagate = True
_configure_library_root_logger()
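# Hedged usage sketch (illustration only, not part of the original module):
# raise the library verbosity, emit a message through a module logger, then
# restore the default level.
if __name__ == "__main__":
    set_verbosity_debug()
    get_logger(__name__).debug("debug messages are now visible")
    set_verbosity(_default_log_level)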

View File

@ -0,0 +1,64 @@
"""Utilities to efficiently compute the SHA 256 hash of a bunch of bytes."""
from typing import BinaryIO, Optional
from .insecure_hashlib import sha1, sha256
def sha_fileobj(fileobj: BinaryIO, chunk_size: Optional[int] = None) -> bytes:
"""
Computes the sha256 hash of the given file object, by chunks of size `chunk_size`.
Args:
fileobj (file-like object):
The File object to compute sha256 for, typically obtained with `open(path, "rb")`
chunk_size (`int`, *optional*):
The number of bytes to read from `fileobj` at once, defaults to 1MB.
Returns:
`bytes`: `fileobj`'s sha256 hash as bytes
"""
chunk_size = chunk_size if chunk_size is not None else 1024 * 1024
sha = sha256()
while True:
chunk = fileobj.read(chunk_size)
sha.update(chunk)
if not chunk:
break
return sha.digest()
def git_hash(data: bytes) -> str:
"""
Computes the git-sha1 hash of the given bytes, using the same algorithm as git.
This is equivalent to running `git hash-object`. See https://git-scm.com/docs/git-hash-object
for more details.
Note: this method is valid for regular files. For LFS files, the proper git hash is supposed to be computed on the
pointer file content, not the actual file content. However, for simplicity, we directly compare the sha256 of
the LFS file content when we want to compare LFS files.
Args:
data (`bytes`):
The data to compute the git-hash for.
Returns:
`str`: the git-hash of `data` as a hexadecimal string.
Example:
```python
>>> from huggingface_hub.utils.sha import git_hash
>>> git_hash(b"Hello, World!")
'b45ef6fec89518d314f546fd6c3025367b721684'
```
"""
# Taken from https://gist.github.com/msabramo/763200
# Note: no need to optimize by reading the file in chunks as we're not supposed to hash huge files (5MB maximum).
sha = sha1()
sha.update(b"blob ")
sha.update(str(len(data)).encode())
sha.update(b"\0")
sha.update(data)
return sha.hexdigest()
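# Hedged usage sketch (illustration only, not part of the original module):
# both helpers on the same payload; the expected git hash comes from the
# `git_hash` docstring above.
if __name__ == "__main__":
    from io import BytesIO

    payload = b"Hello, World!"
    assert sha_fileobj(BytesIO(payload), chunk_size=4) == sha256(payload).digest()
    assert git_hash(payload) == "b45ef6fec89518d314f546fd6c3025367b721684"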

View File

@ -0,0 +1,264 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
"""Utility helpers to handle progress bars in `huggingface_hub`.
Example:
1. Use `huggingface_hub.utils.tqdm` as you would use `tqdm.tqdm` or `tqdm.auto.tqdm`.
2. To disable progress bars, either use `disable_progress_bars()` helper or set the
environment variable `HF_HUB_DISABLE_PROGRESS_BARS` to 1.
3. To re-enable progress bars, use `enable_progress_bars()`.
4. To check whether progress bars are disabled, use `are_progress_bars_disabled()`.
NOTE: Environment variable `HF_HUB_DISABLE_PROGRESS_BARS` has the priority.
Example:
```py
>>> from huggingface_hub.utils import are_progress_bars_disabled, disable_progress_bars, enable_progress_bars, tqdm
# Disable progress bars globally
>>> disable_progress_bars()
# Use as normal `tqdm`
>>> for _ in tqdm(range(5)):
... pass
# Still not showing progress bars, as `disable=False` is overwritten to `True`.
>>> for _ in tqdm(range(5), disable=False):
... pass
>>> are_progress_bars_disabled()
True
# Re-enable progress bars globally
>>> enable_progress_bars()
# Progress bar will be shown !
>>> for _ in tqdm(range(5)):
... pass
100%|███████████████████████████████████████| 5/5 [00:00<00:00, 117817.53it/s]
```
Group-based control:
```python
# Disable progress bars for a specific group
>>> disable_progress_bars("peft.foo")
# Check state of different groups
>>> assert not are_progress_bars_disabled("peft"))
>>> assert not are_progress_bars_disabled("peft.something")
>>> assert are_progress_bars_disabled("peft.foo"))
>>> assert are_progress_bars_disabled("peft.foo.bar"))
# Enable progress bars for a subgroup
>>> enable_progress_bars("peft.foo.bar")
# Check if enabling a subgroup affects the parent group
>>> assert are_progress_bars_disabled("peft.foo"))
>>> assert not are_progress_bars_disabled("peft.foo.bar"))
# No progress bar for `name="peft.foo"`
>>> for _ in tqdm(range(5), name="peft.foo"):
... pass
# Progress bar will be shown for `name="peft.foo.bar"`
>>> for _ in tqdm(range(5), name="peft.foo.bar"):
... pass
100%|███████████████████████████████████████| 5/5 [00:00<00:00, 117817.53it/s]
```
"""
import io
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Iterator, Optional, Union
from tqdm.auto import tqdm as old_tqdm
from ..constants import HF_HUB_DISABLE_PROGRESS_BARS
# The `HF_HUB_DISABLE_PROGRESS_BARS` environment variable can be True, False, or not set (None),
# allowing for control over progress bar visibility. When set, this variable takes precedence
# over programmatic settings, dictating whether progress bars should be shown or hidden globally.
# Essentially, the environment variable's setting overrides any code-based configurations.
#
# If `HF_HUB_DISABLE_PROGRESS_BARS` is not defined (None), it implies that users can manage
# progress bar visibility through code. By default, progress bars are turned on.
progress_bar_states: Dict[str, bool] = {}
def disable_progress_bars(name: Optional[str] = None) -> None:
"""
Disable progress bars either globally or for a specified group.
This function updates the state of progress bars based on a group name.
If no group name is provided, all progress bars are disabled. The operation
respects the `HF_HUB_DISABLE_PROGRESS_BARS` environment variable's setting.
Args:
name (`str`, *optional*):
The name of the group for which to disable the progress bars. If None,
progress bars are disabled globally.
Raises:
Warning: If the environment variable precludes changes.
"""
if HF_HUB_DISABLE_PROGRESS_BARS is False:
warnings.warn(
"Cannot disable progress bars: environment variable `HF_HUB_DISABLE_PROGRESS_BARS=0` is set and has priority."
)
return
if name is None:
progress_bar_states.clear()
progress_bar_states["_global"] = False
else:
keys_to_remove = [key for key in progress_bar_states if key.startswith(f"{name}.")]
for key in keys_to_remove:
del progress_bar_states[key]
progress_bar_states[name] = False
def enable_progress_bars(name: Optional[str] = None) -> None:
"""
Enable progress bars either globally or for a specified group.
This function sets the progress bars to enabled for the specified group or globally
if no group is specified. The operation is subject to the `HF_HUB_DISABLE_PROGRESS_BARS`
environment setting.
Args:
name (`str`, *optional*):
The name of the group for which to enable the progress bars. If None,
progress bars are enabled globally.
Raises:
Warning: If the environment variable precludes changes.
"""
if HF_HUB_DISABLE_PROGRESS_BARS is True:
warnings.warn(
"Cannot enable progress bars: environment variable `HF_HUB_DISABLE_PROGRESS_BARS=1` is set and has priority."
)
return
if name is None:
progress_bar_states.clear()
progress_bar_states["_global"] = True
else:
keys_to_remove = [key for key in progress_bar_states if key.startswith(f"{name}.")]
for key in keys_to_remove:
del progress_bar_states[key]
progress_bar_states[name] = True
def are_progress_bars_disabled(name: Optional[str] = None) -> bool:
"""
Check if progress bars are disabled globally or for a specific group.
This function returns whether progress bars are disabled for a given group or globally.
It checks the `HF_HUB_DISABLE_PROGRESS_BARS` environment variable first, then the programmatic
settings.
Args:
name (`str`, *optional*):
The group name to check; if None, checks the global setting.
Returns:
`bool`: True if progress bars are disabled, False otherwise.
"""
if HF_HUB_DISABLE_PROGRESS_BARS is True:
return True
if name is None:
return not progress_bar_states.get("_global", True)
while name:
if name in progress_bar_states:
return not progress_bar_states[name]
name = ".".join(name.split(".")[:-1])
return not progress_bar_states.get("_global", True)
class tqdm(old_tqdm):
"""
Class to override `disable` argument in case progress bars are globally disabled.
Taken from https://github.com/tqdm/tqdm/issues/619#issuecomment-619639324.
"""
def __init__(self, *args, **kwargs):
name = kwargs.pop("name", None) # do not pass `name` to `tqdm`
if are_progress_bars_disabled(name):
kwargs["disable"] = True
super().__init__(*args, **kwargs)
def __delattr__(self, attr: str) -> None:
"""Fix for https://github.com/huggingface/huggingface_hub/issues/1603"""
try:
super().__delattr__(attr)
except AttributeError:
if attr != "_lock":
raise
@contextmanager
def tqdm_stream_file(path: Union[Path, str]) -> Iterator[io.BufferedReader]:
"""
Open a file as binary and wrap the `read` method to display a progress bar when it's streamed.
First implemented in `transformers` in 2019 but removed when switched to git-lfs. Used in `huggingface_hub` to show
progress bar when uploading an LFS file to the Hub. See github.com/huggingface/transformers/pull/2078#discussion_r354739608
for implementation details.
Note: the current implementation handles only files stored on disk, as it is the most common use case. Could be
extended to stream any `BinaryIO` object but we might have to debug some corner cases.
Example:
```py
>>> with tqdm_stream_file("config.json") as f:
>>> requests.put(url, data=f)
config.json: 100%|█████████████████████████| 8.19k/8.19k [00:02<00:00, 3.72kB/s]
```
"""
if isinstance(path, str):
path = Path(path)
with path.open("rb") as f:
total_size = path.stat().st_size
pbar = tqdm(
unit="B",
unit_scale=True,
total=total_size,
initial=0,
desc=path.name,
)
f_read = f.read
def _inner_read(size: Optional[int] = -1) -> bytes:
data = f_read(size)
pbar.update(len(data))
return data
f.read = _inner_read # type: ignore
yield f
pbar.close()
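# Hedged usage sketch (illustration only, not part of the original module):
# stream a small temporary file through `tqdm_stream_file`; the progress bar is
# driven by the wrapped `read` calls.
if __name__ == "__main__":
    import os
    import tempfile

    tmp_dir = tempfile.mkdtemp()
    file_path = os.path.join(tmp_dir, "example.bin")
    with open(file_path, "wb") as fh:
        fh.write(b"0" * 1024)
    with tqdm_stream_file(file_path) as f:
        while f.read(256):
            pass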