7.1.3. Hub Python Library

Installation

pip install 'huggingface_hub[cli]'
pip install --upgrade 'huggingface_hub[inference]'    # async HTTP support (or: pip install aiohttp)

Basic operations

Login

cli:

huggingface-cli login
# or using an environment variable
huggingface-cli login --token $HUGGINGFACE_TOKEN

python:

from huggingface_hub import login
login()
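
login() prompts for a token interactively; it also accepts a token argument, so a non-interactive sketch (assuming the token is exported as HF_TOKEN) looks like:

import os
from huggingface_hub import login

# non-interactive login; assumes the token is available as the HF_TOKEN environment variable
login(token=os.environ["HF_TOKEN"])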

Create a repository

create a public repo (default):

from huggingface_hub import HfApi
api = HfApi()
api.create_repo(repo_id="super-cool-model")

create a private repo:

from huggingface_hub import HfApi
api = HfApi()
api.create_repo(repo_id="super-cool-model", private=True)

Upload files

from huggingface_hub import HfApi
api = HfApi()
api.upload_file(
    path_or_fileobj="/home/lysandre/dummy-test/README.md",
    path_in_repo="README.md",
    repo_id="lysandre/test-model",
)

Environment variables

Local cache location for pretrained models, resolved in this priority order:

1. Default: the HUGGINGFACE_HUB_CACHE or TRANSFORMERS_CACHE environment variable
2. HF_HOME
3. XDG_CACHE_HOME + /huggingface

Default location: ~/.cache/huggingface/hub

Common environment variables and their defaults:

HF_ENDPOINT: https://huggingface.co
HF_HOME: ~/.cache/huggingface
HUGGINGFACE_HUB_CACHE: ~/.cache/huggingface/hub
HUGGINGFACE_ASSETS_CACHE: ~/.cache/huggingface/assets
HUGGING_FACE_HUB_TOKEN: $HF_HOME/token

HF_HUB_OFFLINE:

If set, no HTTP calls are made when trying to fetch files; only files already in the local cache are used.
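
A minimal sketch of using it from Python (assuming the files are already in the local cache; setting the variable before importing huggingface_hub is the safe order):

import os
os.environ["HF_HUB_OFFLINE"] = "1"   # set before importing huggingface_hub

from huggingface_hub import hf_hub_download

# with the cache already populated, this resolves from disk and makes no HTTP calls
path = hf_hub_download(repo_id="bigscience/T0_3B", filename="config.json")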

Offline mode

Enable offline mode:

Set the environment variable TRANSFORMERS_OFFLINE=1

Set the environment variable HF_DATASETS_OFFLINE=1 to add the Datasets library to your offline training workflow

Example:

python examples/pytorch/translation/run_translation.py --model_name_or_path t5-small \
   --dataset_name wmt16 --dataset_config ro-en ...


Run the same program on an offline instance:
HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 \
python examples/pytorch/translation/run_translation.py --model_name_or_path t5-small \
        --dataset_name wmt16 --dataset_config ro-en ...

Fetch models and tokenizers for offline use

  1. Download files directly from the Model Hub web UI.

  2. Use the PreTrainedModel.from_pretrained() and PreTrainedModel.save_pretrained() workflow:

    a. Download your files ahead of time with `PreTrainedModel.from_pretrained()`:
    
            from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
            tokenizer = AutoTokenizer.from_pretrained("bigscience/T0_3B")
            model = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0_3B")
    
    b. Save your files to a specified directory with `PreTrainedModel.save_pretrained()`:
    
            tokenizer.save_pretrained("./your/path/bigscience_t0")
            model.save_pretrained("./your/path/bigscience_t0")
    
    c. Now when you’re offline, reload your files with `PreTrainedModel.from_pretrained()` from the specified directory:
    
            tokenizer = AutoTokenizer.from_pretrained("./your/path/bigscience_t0")
            model = AutoModelForSeq2SeqLM.from_pretrained("./your/path/bigscience_t0")
    
  3. Programmatically download files with the huggingface_hub library:

    a. Install the huggingface_hub library in your virtual environment:
            python -m pip install huggingface_hub
    
    b. Use the `hf_hub_download()` function to download a file to a specific path.
            from huggingface_hub import hf_hub_download
            hf_hub_download(repo_id="bigscience/T0_3B", filename="config.json", cache_dir="./your/path/bigscience_t0")
    
    c. Once your file is downloaded and locally cached, specify its local path to load and use it:
            from transformers import AutoConfig
            config = AutoConfig.from_pretrained("./your/path/bigscience_t0/config.json")
    

Download files from the Hub

Download and cache a single file:

>>> from huggingface_hub import hf_hub_download
>>> hf_hub_download(repo_id="lysandre/arxiv-nlp", filename="config.json")
'/root/.cache/huggingface/hub/models--lysandre--arxiv-nlp/snapshots/894a...fade/config.json'

# Download from a dataset
hf_hub_download(repo_id="google/fleurs", filename="fleurs.py", repo_type="dataset")

# Download from the `v1.0` tag
hf_hub_download(repo_id="lysandre/arxiv-nlp", filename="config.json", revision="v1.0")

Download and cache an entire repository:

>>> from huggingface_hub import snapshot_download
>>> snapshot_download(repo_id="lysandre/arxiv-nlp")
'/home/lysandre/.cache/huggingface/hub/models--lysandre--arxiv-nlp/snapshots/894a...fade'

Download files to a local folder:

When calling hf_hub_download() and snapshot_download(), the following parameters control where files are stored:
cache_dir: download into a specific cache directory
filename: the file to fetch from the repo (hf_hub_download() only)
local_dir: download into a specific local folder instead of the cache
local_dir_use_symlinks: whether files placed in local_dir should be symlinks into the cache


Example:
        filename="data/train.csv"
        local_dir="path/to/folder"
        => the returned filepath will be "path/to/folder/data/train.csv".
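
A sketch of such a call (repo name hypothetical), downloading straight into a local folder instead of the cache:

from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="my-username/my-dataset-repo",
    repo_type="dataset",
    filename="data/train.csv",
    local_dir="path/to/folder",
)
print(path)  # path/to/folder/data/train.csv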

Upload files to the Hub

Upload a file:

from huggingface_hub import HfApi
api = HfApi()
api.upload_file(
    path_or_fileobj="/path/to/local/folder/README.md",
    path_in_repo="README.md",
    repo_id="username/test-dataset",
    repo_type="dataset",
)

Upload a folder:

from huggingface_hub import HfApi
api = HfApi()

api.upload_folder(
    folder_path="/path/to/local/space",
    repo_id="username/my-cool-space",
    repo_type="space",
)
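
upload_folder() also accepts allow_patterns and ignore_patterns to filter which files get pushed; a sketch with hypothetical paths and patterns:

from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path="/path/to/local/checkpoints",
    repo_id="username/my-model",
    allow_patterns="*.safetensors",     # only upload weight files
    ignore_patterns="**/logs/*.txt",    # skip log files
)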

Non-blocking upload (using the run_as_future parameter):

>> from huggingface_hub import HfApi
>> api = HfApi()
>> future = api.upload_folder( # Upload in the background (non-blocking action)
..     repo_id="username/my-model",
..     folder_path="checkpoints-001",
..     run_as_future=True,
.. )

>> future
Future(...)

>> future.done()
False

>> future.result() # Wait for the upload to complete (blocking action)
...

Upload a folder in several commits (multi-commit mode):

upload_folder(
    folder_path="local/checkpoints",
    repo_id="username/my-dataset",
    repo_type="dataset",
    multi_commits=True,
    multi_commits_verbose=True,
)

HfFileSystem

  • Interact with the Hub through the Filesystem API

Usage

from huggingface_hub import HfFileSystem
fs = HfFileSystem()

# List all files in a directory
fs.ls("datasets/my-username/my-dataset-repo/data", detail=False)

# List all ".csv" files in a repo
fs.glob("datasets/my-username/my-dataset-repo/**.csv")

# Read a remote file
with fs.open("datasets/my-username/my-dataset-repo/data/train.csv", "r") as f:
    train_data = f.readlines()

# Read the content of a remote file as a string
train_data = fs.read_text("datasets/my-username/my-dataset-repo/data/train.csv", revision="dev")

# Write a remote file
with fs.open("datasets/my-username/my-dataset-repo/data/validation.csv", "w") as f:
    f.write("text,label")
    f.write("Fantastic movie!,good")

Integrations

Path format:

hf://[<repo_type_prefix>]<repo_id>[@<revision>]/<path/in/repo>

Reading/writing a Pandas DataFrame from/to a Hub repository:

import pandas as pd

# Read a remote CSV file into a dataframe
df = pd.read_csv("hf://datasets/my-username/my-dataset-repo/train.csv")

# Write a dataframe to a remote CSV file
df.to_csv("hf://datasets/my-username/my-dataset-repo/test.csv")

Querying (remote) Hub files with DuckDB:

from huggingface_hub import HfFileSystem
import duckdb

fs = HfFileSystem()
duckdb.register_filesystem(fs)
# Query a remote file and get the result back as a dataframe
fs_query_file = "hf://datasets/my-username/my-dataset-repo/data_dir/data.parquet"
df = duckdb.query(f"SELECT * FROM '{fs_query_file}' LIMIT 10").df()

Authentication

from huggingface_hub import HfFileSystem
fs = HfFileSystem(token=token)

Repository

Repo creation and deletion

create_repo():

>> from huggingface_hub import create_repo
>> create_repo("lysandre/test-model")
'https://huggingface.co/lysandre/test-model'


>> create_repo("lysandre/test-dataset", repo_type="dataset")
>> create_repo("lysandre/test-private", private=True)

delete_repo():

delete_repo(repo_id="lysandre/my-corrupted-dataset", repo_type="dataset")

Duplicate a repository (only for Spaces):

>> from huggingface_hub import duplicate_space
>> duplicate_space("multimodalart/dreambooth-training", private=False)
RepoUrl('https://huggingface.co/spaces/nateraw/dreambooth-training',...)

Other operations

from huggingface_hub import create_branch, create_tag
from huggingface_hub import list_repo_refs
from huggingface_hub import update_repo_visibility
from huggingface_hub import move_repo
from huggingface_hub import whoami

from huggingface_hub import InferenceClient
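
A short sketch exercising a few of these helpers (repo name hypothetical, write access assumed):

from huggingface_hub import create_branch, create_tag, list_repo_refs, whoami

# create a branch and a tag on an existing repo
create_branch("username/test-model", branch="experiment")
create_tag("username/test-model", tag="v0.1", tag_message="first tag")

# list branches/tags and print the logged-in user
print(list_repo_refs("username/test-model"))
print(whoami()["name"])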

Manage a local copy of your repository

Use a local repository:

from huggingface_hub import Repository
repo = Repository(local_dir="<path>/<to>/<folder>")

Clone:

from huggingface_hub import Repository
repo = Repository(local_dir="w2v2", clone_from="facebook/wav2vec2-large-960h-lv60")


repo = Repository(local_dir="huggingface-hub", clone_from="https://huggingface.co/facebook/wav2vec2-large-960h-lv60")

Inference

There are several services you can connect to:

Inference API:
        a service that allows you to run accelerated inference
                on Hugging Face’s infrastructure for free.
        This service is a fast way to get started, test different models,
                and prototype AI products.
Inference Endpoints:
        a product to easily deploy models to production.
        Inference is run by Hugging Face in a dedicated,
                fully managed infrastructure on a cloud provider of your choice.

Currently supported tasks:

+------------+----------------------------------+----+----------------------------------+
| Domain     | Task                             | Sup| Method                           |
+============+==================================+====+==================================+
| Audio      | Audio Classification             | ✅ | audio_classification()           |
+------------+----------------------------------+----+----------------------------------+
|            | Automatic Speech Recognition     | ✅ | automatic_speech_recognition()   |
+------------+----------------------------------+----+----------------------------------+
|            | Text-to-Speech                   | ✅ | text_to_speech()                 |
+------------+----------------------------------+----+----------------------------------+
| CV         | Image Classification             | ✅ | image_classification()           |
+------------+----------------------------------+----+----------------------------------+
|            | Image Segmentation               | ✅ | image_segmentation()             |
+------------+----------------------------------+----+----------------------------------+
|            | Image-to-Image                   | ✅ | image_to_image()                 |
+------------+----------------------------------+----+----------------------------------+
|            | Image-to-Text                    | ✅ | image_to_text()                  |
+------------+----------------------------------+----+----------------------------------+
|            | Object Detection                 |    |                                  |
+------------+----------------------------------+----+----------------------------------+
|            | Text-to-Image                    | ✅ | text_to_image()                  |
+------------+----------------------------------+----+----------------------------------+
|            | Zero-Shot-Image-Classification   | ✅ | zero_shot_image_classification() |
+------------+----------------------------------+----+----------------------------------+
| Multimodal | Documentation Question Answering |    |                                  |
+------------+----------------------------------+----+----------------------------------+
|            | Visual Question Answering        |    |                                  |
+------------+----------------------------------+----+----------------------------------+
| NLP        | Conversational                   | ✅ | conversational()                 |
+------------+----------------------------------+----+----------------------------------+
|            | Feature Extraction               | ✅ | feature_extraction()             |
+------------+----------------------------------+----+----------------------------------+
|            | Fill Mask                        |    |                                  |
+------------+----------------------------------+----+----------------------------------+
|            | Question Answering               |    |                                  |
+------------+----------------------------------+----+----------------------------------+
|            | Sentence Similarity              | ✅ | sentence_similarity()            |
+------------+----------------------------------+----+----------------------------------+
|            | Summarization                    | ✅ | summarization()                  |
+------------+----------------------------------+----+----------------------------------+
|            | Table Question Answering         |    |                                  |
+------------+----------------------------------+----+----------------------------------+
|            | Text Classification              |    |                                  |
+------------+----------------------------------+----+----------------------------------+
|            | Text Generation                  | ✅ | text_generation()                |
+------------+----------------------------------+----+----------------------------------+
|            | Token Classification             |    |                                  |
+------------+----------------------------------+----+----------------------------------+
|            | Translation                      |    |                                  |
+------------+----------------------------------+----+----------------------------------+
|            | Zero Shot Classification         |    |                                  |
+------------+----------------------------------+----+----------------------------------+
| Tabular    | Tabular Classification           |    |                                  |
+------------+----------------------------------+----+----------------------------------+
|            | Tabular Regression               |    |                                  |
+------------+----------------------------------+----+----------------------------------+

Inference Client

Creating a client:

from huggingface_hub import InferenceClient

# default token
client = InferenceClient()

# Authentication
client = InferenceClient(token='hf_.....')

# Using a specific model
client = InferenceClient(model="prompthero/openjourney-v4")

# Using a specific URL
client = InferenceClient(model="https://uu149rez6gw9ehej.eu-west-1.aws.endpoints.huggingface.cloud/deepfloyd-if")

Usage example:

from huggingface_hub import InferenceClient
client = InferenceClient()
image = client.image_to_image("cat.jpg", prompt="turn the cat into a tiger")
image.save("tiger.jpg")

Note

The above is just one example; the official documentation contains usage examples for every supported task.
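
For instance, a sketch of another task, text generation (the service picks a recommended model since none is specified):

from huggingface_hub import InferenceClient

client = InferenceClient()
generated = client.text_generation("The huggingface_hub library is ", max_new_tokens=12)
print(generated)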

Custom requests

from huggingface_hub import InferenceClient
client = InferenceClient()
response = client.post(
        json={"inputs": "An astronaut riding a horse on the moon."},
        model="stabilityai/stable-diffusion-2-1"
)
response.content # raw bytes

Async Inference Client

Install:

pip install --upgrade 'huggingface_hub[inference]'
# or
# pip install aiohttp

Creating an async client:

from huggingface_hub import AsyncInferenceClient
client = AsyncInferenceClient(token='hf_.....')
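
The async client exposes the same task methods as awaitables; a minimal sketch:

import asyncio
from huggingface_hub import AsyncInferenceClient

async def main():
    client = AsyncInferenceClient()
    # same API as InferenceClient, but awaited
    text = await client.text_generation("The huggingface_hub library is ", max_new_tokens=12)
    print(text)

asyncio.run(main())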

InferenceAPI

示例:

from huggingface_hub.inference_api import InferenceApi

# Mask-fill example
inference = InferenceApi("bert-base-uncased")
inference(inputs="The goal of life is [MASK].")

# Question Answering example
inference = InferenceApi("deepset/roberta-base-squad2")
inputs = {
    "question": "What's my name?",
    "context": "My name is Clara and I live in Berkeley.",
}
inference(inputs)

# Zero-shot example
inference = InferenceApi("typeform/distilbert-base-uncased-mnli")
inputs = "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!"
params = {"candidate_labels": ["refund", "legal", "faq"]}
inference(inputs, params)

# Overriding configured task
inference = InferenceApi("bert-base-uncased", task="feature-extraction")

# Return as raw response to parse the output yourself
inference = InferenceApi("mio/amadeus")
response = inference("hello world", raw_response=True)
response.headers
response.content # raw bytes from server

Cache

Understand caching

  • The Hugging Face Hub cache-system is designed to be the central cache shared across libraries that depend on the Hub.

The caching system is designed as follows:

<CACHE_DIR>
├─ <MODELS>
├─ <DATASETS>
├─ <SPACES>

Example:
<CACHE_DIR>
├─ models--julien-c--EsperBERTo-small
├─ models--lysandrejik--arxiv-nlp
├─ models--bert-base-cased
├─ datasets--glue
├─ datasets--huggingface--DataMeasurementsFiles
├─ spaces--dalle-mini--dalle-mini

  • <CACHE_DIR> is customizable with the cache_dir argument on all methods, or by specifying either the HF_HOME or HUGGINGFACE_HUB_CACHE environment variable.
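
For example, a sketch overriding the cache location for a single call (the cache path is hypothetical):

from huggingface_hub import hf_hub_download

hf_hub_download(
    repo_id="lysandre/arxiv-nlp",
    filename="config.json",
    cache_dir="/mnt/hf-cache",   # hypothetical custom cache location
)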

Caching assets

  • In addition to caching files from the Hub, downstream libraries often need to cache other files related to HF that are not handled directly by huggingface_hub (for example: files downloaded from GitHub, preprocessed data, logs, ...).

  • To cache these files, called assets, huggingface_hub provides the cached_assets_path() utility to manage the cache:

    from huggingface_hub import cached_assets_path
    
    assets_path = cached_assets_path(
            library_name="datasets",
            namespace="SQuAD",
            subfolder="download"
    )
    # Do anything you like in your assets folder !
    something_path = assets_path / "something.json"
    

Scan your cache

Scan cache from the terminal

Example:

$ huggingface-cli scan-cache
REPO ID             REPO TYPE SIZE ON DISK LAST_MODIFIED REFS                LOCAL PATH
----------------    --------- ------------ ------------- ------------------- -------------
glue                dataset         116.3K 4 days ago    2.4.0, main, 1.17.0 ~/.cache/huggingface/hub/datasets--glue
google/fleurs       dataset          64.9M 1 week ago    refs/pr/1, main     ~/.cache/huggingface/hub/datasets--google--fleurs
Jean-Baptiste/ner   model           441.0M 16 hours ago  main                ~/.cache/huggingface/hub/models--Jean-Baptiste--ner
bert-base-cased     model             1.9G 2 years ago                       ~/.cache/huggingface/hub/models--bert-base-cased
t5-base             model            10.1K 3 months ago  main                ~/.cache/huggingface/hub/models--t5-base
t5-small            model           970.7M 3 days ago    refs/pr/1, main     ~/.cache/huggingface/hub/models--t5-small

Done in 0.0s. Scanned 6 repo(s) for a total of 3.4G.
Got 1 warning(s) while scanning. Use -vvv to print details.

Example:

huggingface-cli scan-cache -v
huggingface-cli scan-cache -v | grep "t5-small"

Scan cache from Python
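
The same information is available programmatically with scan_cache_dir(); a minimal sketch:

from huggingface_hub import scan_cache_dir

report = scan_cache_dir()
print(f"Total size on disk: {report.size_on_disk} bytes")
for repo in report.repos:
    # one entry per cached repo, with its type and human-readable size
    print(repo.repo_id, repo.repo_type, repo.size_on_disk_str)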

Clean your cache

Clean cache from the terminal

Using the TUI (Terminal User Interface):

# Install
pip install 'huggingface_hub[cli]'

# Usage
huggingface-cli delete-cache

huggingface-cli delete-cache --dir ~/.cache/huggingface/hub

Clean cache from Python

from huggingface_hub import scan_cache_dir

delete_strategy = scan_cache_dir().delete_revisions(
    "81fd1d6e7847c99f5862c9fb81387956d99ec7aa"
    "e2983b237dccf3ab4937c97fa717319a9ca1a96d",
    "6c0e6080953db56375760c0471a8c5f2929baf11",
)
print("Will free " + delete_strategy.expected_freed_size_str)

delete_strategy.execute()

Model Card

Load a Model Card from the Hub

from huggingface_hub import ModelCard
card = ModelCard.load('nateraw/vit-base-beans')

=>
card.data: Returns a ModelCardData instance with the model card’s metadata.
        Call .to_dict() on this instance to get the representation as a dictionary.
card.text: Returns the text of the card, excluding the metadata header.
card.content: Returns the text content of the card, including the metadata header.

=>
card.data.to_dict()['language']
card.data.to_dict()['license']
...

Create Model Cards

There are several ways to create one; the simplest is to instantiate ModelCard from a string:

content = """
---
language: en
license: mit
---

# My Model Card
"""

card = ModelCard(content)

Share Model Cards

create a new repo called ‘hf-hub-modelcards-pr-test’:

from huggingface_hub import whoami, create_repo

user = whoami()['name']
repo_id = f'{user}/hf-hub-modelcards-pr-test'
url = create_repo(repo_id, exist_ok=True)

create a card from the default template:

from huggingface_hub import ModelCard, ModelCardData

card_data = ModelCardData(language='en', license='mit', library_name='keras')
card = ModelCard.from_template(
    card_data,
    model_id='my-cool-model',
    model_description="this model does this and that",
    developers="Nate Raw",
    repo="https://github.com/huggingface/huggingface_hub",
)

push that up to the hub:

>> card.push_to_hub(repo_id)
or push a card as a pull request
>> card.push_to_hub(repo_id, create_pr=True)

Include Evaluation Results

from huggingface_hub import EvalResult, ModelCard, ModelCardData

card_data = ModelCardData(
    language='en',
    license='mit',
    model_name='my-cool-model',
    eval_results = EvalResult(
        task_type='image-classification',
        dataset_type='beans',
        dataset_name='Beans',
        metric_type='accuracy',
        metric_value=0.7
    )
)

card = ModelCard.from_template(card_data)
print(card.data)

card.data should look like this:

language: en
license: mit
model-index:
- name: my-cool-model
  results:
  - task:
      type: image-classification
    dataset:
      name: Beans
      type: beans
    metrics:
    - type: accuracy
      value: 0.7

Integrate Library

There are four main ways to integrate a library with the Hub:

1. Push to Hub:
        implement a method to upload a model to the Hub.
        This includes the model weights, model card and necessary data(for example, training logs).
        This method is often called `push_to_hub()`.
2. Download from Hub:
        implement a method to load a model from the Hub.
        The method should download the model configuration/weights and load the model.
        This method is often called `from_pretrained()` or `load_from_hub()`.
3. Inference API:
        use our servers to run inference on models supported by your library for free.
4. Widgets:
        display a widget on the landing page of your models on the Hub.
        It allows users to quickly try a model from the browser.

A flexible approach: helpers

  • The first approach to integrating a library with the Hub is to implement the push_to_hub and from_pretrained methods yourself.

from_pretrained:

from huggingface_hub import hf_hub_download

def from_pretrained(model_id: str) -> MyModelClass:
    # Download model from Hub
    cached_model = hf_hub_download(
        repo_id=model_id,
        filename="model.pkl",
        library_name="fastai",
        library_version=get_fastai_version(),
    )

    # Load model
    return load_model(cached_model)

push_to_hub:

from pathlib import Path
from tempfile import TemporaryDirectory

from huggingface_hub import HfApi

def push_to_hub(model: MyModelClass, repo_name: str) -> None:
    api = HfApi()

    # Create repo if not existing yet and get the associated repo_id
    repo_id = api.create_repo(repo_name, exist_ok=True).repo_id

    # Save all files in a temporary directory and push them in a single commit
    with TemporaryDirectory() as tmpdir:
        tmpdir = Path(tmpdir)

        # Save weights
        save_model(model, tmpdir / "model.safetensors")

        # Generate model card
        card = generate_model_card(model)
        (tmpdir / "README.md").write_text(card)

        # Save logs
        # Save figures
        # Save evaluation metrics
        # ...

        # Push to hub
        return api.upload_folder(repo_id=repo_id, folder_path=tmpdir)

Limitations

when loading files from the Hub, it is common to offer parameters like:

token: to download from a private repo
revision: to download from a specific branch
cache_dir: to cache files in a specific directory
...

when pushing models, similar parameters are supported:

commit_message: custom commit message
private: create a private repo if missing
create_pr: create a PR instead of pushing to main
...

Note

All of these parameters can be added to the implementations we saw above and passed to the huggingface_hub methods. However, if a parameter changes or a new feature is added, you will need to update your package. Supporting those parameters also means more documentation to maintain on your side.

A more complex approach: class inheritance

  • A Mixin is a class that is meant to extend an existing class with a set of specific features using multiple inheritance.

  • Example: huggingface_hub provides its own mixin, the ModelHubMixin.

  • The ModelHubMixin class implements 3 public methods (push_to_hub, save_pretrained and from_pretrained).

Steps:

1. Make your Model class inherit from ModelHubMixin.
2. Implement the private methods:
        _save_pretrained()
        _from_pretrained()

A concrete example: PyTorch

How to use it

Example:

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

# 1. Define your Pytorch model exactly the same way you are used to
class MyModel(nn.Module, PyTorchModelHubMixin): # multiple inheritance
    def __init__(self):
        super().__init__()
        self.param = nn.Parameter(torch.rand(3, 4))
        self.linear = nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param)

model = MyModel()

# 2. (optional) Save model to local directory
model.save_pretrained("path/to/my-awesome-model")

# 3. Push model weights to the Hub
model.push_to_hub("my-awesome-model")

# 4. Initialize model from the Hub
model = MyModel.from_pretrained("username/my-awesome-model")

Implementation

  1. First, inherit your class from ModelHubMixin:

    from huggingface_hub import ModelHubMixin
    
    class PyTorchModelHubMixin(ModelHubMixin):
       (...)
    
  2. Implement the _save_pretrained method:

    from huggingface_hub import ModelCard, ModelCardData
    
    class PyTorchModelHubMixin(ModelHubMixin):
       (...)
    
       def _save_pretrained(self, save_directory: Path):
          """Generate Model Card and save weights from a Pytorch model to a local directory."""
          model_card = ModelCard.from_template(
             card_data=ModelCardData(
                license='mit',
                library_name="pytorch",
                ...
             ),
             model_summary=...,
             model_type=...,
             ...
          )
          (save_directory / "README.md").write_text(str(model_card))
          torch.save(obj=self.module.state_dict(), f=save_directory / "pytorch_model.bin")
    
  3. Implement the _from_pretrained method:

    class PyTorchModelHubMixin(ModelHubMixin):
       (...)
    
       @classmethod # Must be a classmethod!
       def _from_pretrained(
          cls,
          *,
          model_id: str,
          revision: str,
          cache_dir: str,
          force_download: bool,
          proxies: Optional[Dict],
          resume_download: bool,
          local_files_only: bool,
          token: Union[str, bool, None],
          map_location: str = "cpu", # additional argument
          strict: bool = False, # additional argument
          **model_kwargs,
       ):
          """Load Pytorch pretrained weights and return the loaded model."""
          if os.path.isdir(model_id): # Can either be a local directory
             print("Loading weights from local directory")
             model_file = os.path.join(model_id, "pytorch_model.bin")
          else: # Or a model on the Hub
             model_file = hf_hub_download( # Download from the hub, passing same input args
                repo_id=model_id,
                filename="pytorch_model.bin",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
             )
    
          # Load model and return - custom logic depending on your framework
          model = cls(**model_kwargs)
          state_dict = torch.load(model_file, map_location=torch.device(map_location))
          model.load_state_dict(state_dict, strict=strict)
          model.eval()
          return model
    

Quick comparison

  1. User experience:

    - helpers: the user must call separate load and push functions; more code to write
    - ModelHubMixin: the model class inherits from_pretrained and push_to_hub directly; simpler and more intuitive

  2. Flexibility:

    - helpers: very flexible, the implementation can be fully customized
    - ModelHubMixin: less flexible; your framework must expose a model class

  3. Maintenance:

    - helpers: configuration, new features, etc. are maintained by you, and you resolve user-reported issues yourself
    - ModelHubMixin: most of the interaction with the Hub is handled by huggingface_hub, so there is less to maintain

  4. Documentation / type annotations:

    - helpers: written manually
    - ModelHubMixin: partially generated automatically by huggingface_hub
    
