有没有可能不首先在本地持久化,就可以从GCS bucket URL加载预训练的Pytorch模型?

2024-09-29 21:59:54 发布

您现在位置:Python中文网/ 问答频道 /正文

我是在Google数据流的背景下提出这个问题的,但也是一般性的。在

使用PyTorch,我可以引用包含多个文件的本地目录,这些文件组成了一个预先训练的模型。我碰巧使用的是Roberta模型,但是其他人的界面是一样的。在

ls some-directory/
      added_tokens.json
      config.json             
      merges.txt              
      pytorch_model.bin       
      special_tokens_map.json vocab.json
^{pr2}$

但是,我的预训练模型存储在GCS存储桶中。我们称之为gs://my-bucket/roberta/。在

在googledataflow中加载这个模型的上下文中,我试图保持无状态并避免持久化到磁盘上,所以我倾向于直接从GCS获取这个模型。据我所知,PyTorch通用接口方法from_pretrained()可以采用本地目录或URL的字符串表示。但是,我似乎无法从GCS URL加载模型。在

# this fails
model = RobertaModel.from_pretrained('gs://my-bucket/roberta/')
# ValueError: unable to parse gs://mahmed_bucket/roberta-base as a URL or as a local path

如果我尝试使用目录blob的公共https URL,它也会失败,尽管这可能是由于lack of authentication造成的,因为python环境中引用的可以创建客户端的凭据不会转换为https://storage.googleapis的公共请求

# this fails, probably due to auth
bucket = gcs_client.get_bucket('my-bucket')
directory_blob = bucket.blob(prefix='roberta')
model = RobertaModel.from_pretrained(directory_blob.public_url)
# ValueError: No JSON object could be decoded

# and for good measure, it also fails if I append a trailing /
model = RobertaModel.from_pretrained(directory_blob.public_url + '/')
# ValueError: No JSON object could be decoded

我理解GCS doesn't actually have subdirectories,它实际上只是bucket名称下的一个平面名称空间。然而,我似乎被认证的必要性和一个不说话的PyTorch所阻碍。在

我可以通过先在本地保存文件来解决这个问题。在

from pytorch_transformers import RobertaModel
from google.cloud import storage
import tempfile

local_dir = tempfile.mkdtemp()
gcs = storage.Client()
bucket = gcs.get_bucket(bucket_name)
blobs = bucket.list_blobs(prefix=blob_prefix)
for blob in blobs:
    blob.download_to_filename(local_dir + '/' + os.path.basename(blob.name))
model = RobertaModel.from_pretrained(local_dir)

但这看起来像是一次黑客攻击,我一直在想我一定是错过了什么。当然,有一种方法可以保持无状态,而不必依赖磁盘持久性!在

  • 那么有没有一种方法可以加载存储在GCS中的预训练模型?在
  • 在这个上下文中执行公共URL请求时,有没有一种方法可以进行身份验证?在
  • 即使有认证的方法,子目录的不存在是否仍然是一个问题?在

谢谢你的帮助!我也很高兴有人指出任何重复的问题,因为我肯定找不到任何问题。在


编辑和澄清

  • 我的Python会话已经通过GCS的身份验证,这就是为什么我能够在本地下载blob文件,然后使用load_frompretrained()

    指向本地目录
  • load_frompretrained()需要一个目录引用,因为它需要问题顶部列出的所有文件,而不仅仅是pytorch-model.bin

  • 为了澄清问题2,我想知道是否有某种方法可以给PyTorch方法一个嵌入了加密凭证的请求URL或类似的东西。有点长,但我想确保我没有错过任何东西。

  • 为了澄清问题3(除了下面对一个答案的评论),即使有一种方法可以在我不知道的URL中嵌入凭据,我仍然需要引用一个目录而不是一个blob,我不知道GCS子目录是否会被识别,因为GCS中的子目录(如googledocs所述)是一种错觉,它们并不代表真正的目录结构。所以我觉得这个问题是无关紧要的,或者至少被第2个问题挡住了,但这是一个我追求的线索,所以我仍然很好奇。


Tags: 文件方法from模型目录jsonurlmodel
3条回答

主要编辑:

您可以在Dataflow worker上安装wheel文件,还可以使用worker temp存储在本地保存二进制文件!在

事实上(目前截止到2019年11月),你不能通过提供一个--requirements参数来做到这一点。相反,您必须像这样使用setup.py。假设CAPS中的任何常量都是在别处定义的。在

REQUIRED_PACKAGES = [
    'torch==1.3.0',
    'pytorch-transformers==1.2.0',
]

setup(
    name='project_dir',
    version=VERSION,
    packages=find_packages(),
    install_requires=REQUIRED_PACKAGES)

运行脚本

^{pr2}$

在这个工作中,我们从GCS下载并在一个定制的数据流操作符的上下文中使用这个模型。为了方便起见,我们将一些实用程序方法包装在一个单独的模块中(这对于避免数据流依赖性上载很重要),并将它们导入到自定义运算符的本地范围,而不是全局范围。在

class AddColumn(beam.DoFn):
    PRETRAINED_MODEL = 'gs://my-bucket/blah/roberta-model-files'

    def get_model_tokenizer_wrapper(self):
        import shutil
        import tempfile
        import dataflow_util as util
        try:
            return self.model_tokenizer_wrapper
        except AttributeError:
            tmp_dir = tempfile.mkdtemp() + '/'
            util.download_tree(self.PRETRAINED_MODEL, tmp_dir)
            model, tokenizer = util.create_model_and_tokenizer(tmp_dir)
            model_tokenizer_wrapper = util.PretrainedPyTorchModelWrapper(
                model, tokenizer)
            shutil.rmtree(tmp_dir)
            self.model_tokenizer_wrapper = model_tokenizer_wrapper
            logging.info(
                'Successfully created PretrainedPyTorchModelWrapper')
            return self.model_tokenizer_wrapper

    def process(self, elem):
        model_tokenizer_wrapper = self.get_model_tokenizer_wrapper()

        # And now use that wrapper to process your elem however you need.
        # Note that when you read from BQ your elements are dictionaries
        # of the column names and values for each BQ row.

代码库中独立模块中的实用程序函数。在我们的项目根中,这是在dataflow\u util中/初始py但你不必那样做。在

from contextlib import closing
import logging

import apache_beam as beam
import numpy as np
from pytorch_transformers import RobertaModel, RobertaTokenizer
import torch

class PretrainedPyTorchModelWrapper():
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

def download_tree(gcs_dir, local_dir):
    gcs = beam.io.gcp.gcsio.GcsIO()
    assert gcs_dir.endswith('/')
    assert local_dir.endswith('/')
    for entry in gcs.list_prefix(gcs_dir):
        download_file(gcs, gcs_dir, local_dir, entry)


def download_file(gcs, gcs_dir, local_dir, entry):
    rel_path = entry[len(gcs_dir):]
    dest_path = local_dir + rel_path
    logging.info('Downloading %s', dest_path)
    with closing(gcs.open(entry)) as f_read:
        with open(dest_path, 'wb') as f_write:
            # Download the file in chunks to avoid requiring large amounts of
            # RAM when downloading large files.
            while True:
                file_data_chunk = f_read.read(
                    beam.io.gcp.gcsio.DEFAULT_READ_BUFFER_SIZE)
                if len(file_data_chunk):
                    f_write.write(file_data_chunk)
                else:
                    break


def create_model_and_tokenizer(local_model_path_str):
    """
    Instantiate transformer model and tokenizer

      :param local_model_path_str: string representation of the local path 
             to the directory containing the pretrained model
      :return: model, tokenizer
    """
    model_class, tokenizer_class = (RobertaModel, RobertaTokenizer)

    # Load the pretrained tokenizer and model
    tokenizer = tokenizer_class.from_pretrained(local_model_path_str)
    model = model_class.from_pretrained(local_model_path_str)

    return model, tokenizer

伙计们,你们有了!更多详细信息可在此处找到:https://beam.apache.org/documentation/sdks/python-pipeline-dependencies/


我发现,整个问题链都是无关紧要的,因为Dataflow只允许您在worker上安装源分发包,这意味着您无法实际安装PyTorch。

当您提供一个requirements.txt文件时,Dataflow将使用--no-binary标志进行安装,这将阻止安装Wheel(.whl)包,并且只允许源分发(。焦油gz). 我决定在谷歌数据流上为我自己的PyTrac发布自己的源代码,其中有一半C++和部分CUDA,还有一部分知道傻瓜的任务。在

谢谢你们一直以来的投入。在

正如您所正确指出的,开箱即用的pytorch-transformers似乎不支持这一点,但主要是因为它不将文件链接识别为URL。在

经过一番搜索,我在this source file中找到了相应的错误消息,大约在第144-155行附近。在

当然,您可以尝试将您的'gs'标记添加到第144行,然后将您与GCS的连接解释为http请求(第269-272行)。如果地面军事系统接受这一点,那么这应该是唯一需要改变才能工作的事情。
如果这不起作用,唯一直接的解决办法就是实现类似于amazons3bucket函数的东西,但是我对S3和GCS bucket的了解还不够,无法在这里做出任何有意义的判断。在

我对Pythorch或Roberta模型了解不多,但我会尽量回答您关于GCS的询问:

1.—“那么有没有办法加载存储在GCS中的预训练模型?”在

如果您的模型可以直接从二进制文件加载Blob:

from google.cloud import storage

client = storage.Client()
bucket = client.get_bucket("bucket name")
blob = bucket.blob("path_to_blob/blob_name.ext")
data = blob.download_as_string() # you will have your binary data transformed into string here.

2.-“在这种情况下,有没有一种方法可以在执行公共URL请求时进行身份验证?”在

这是一个棘手的部分,因为根据运行脚本的上下文,它将使用默认服务帐户进行身份验证。因此,当您使用官方GCP LIB时,您可以:

A.-授予默认服务帐户访问bucket/对象的权限。在

B.-创建一个新的服务帐户,并在脚本中对其进行身份验证(您还需要为该服务帐户生成身份验证令牌):

^{pr2}$

但是这是有效的,因为官方libs在后台处理对API调用的身份验证,因此在from_pretrained()函数的情况下不起作用。在

因此,另一种方法是将对象公开,这样您就可以在使用公共url时访问它。在

3.-“即使有一种方法可以验证,子目录的不存在是否仍然是一个问题?”在

不确定你的意思是这里,你可以在你的桶里有文件夹。在

相关问题 更多 >

    热门问题