<p>MAJOR EDIT:</p>
<p>You CAN install wheel files on Dataflow workers, and you can also use the workers' temp storage to persist binary files locally!</p>
<p>It's true that (currently, as of November 2019) you cannot do this by supplying a <code>--requirements</code> argument. Instead you have to use a <code>setup.py</code> like the one below. Assume any constants in CAPS are defined elsewhere.</p>
<pre class="lang-py prettyprint-override"><code>from setuptools import setup, find_packages

REQUIRED_PACKAGES = [
    'torch==1.3.0',
    'pytorch-transformers==1.2.0',
]

setup(
    name='project_dir',
    version=VERSION,  # defined elsewhere
    packages=find_packages(),
    install_requires=REQUIRED_PACKAGES)
</code></pre>
<p>Then run your pipeline script, pointing Dataflow at this file via the <code>--setup_file</code> pipeline option (see the Beam dependency-management docs linked below).</p>
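<p>For illustration only (the script name, project, and bucket below are placeholders, not from the original answer), a launch using <code>--setup_file</code> might look like:</p>
<pre class="lang-sh prettyprint-override"><code># Hypothetical invocation; substitute your own script and GCP settings.
python run_pipeline.py \
  --runner DataflowRunner \
  --project my-gcp-project \
  --temp_location gs://my-bucket/tmp \
  --setup_file ./setup.py
</code></pre>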
<p>In the job we download the model from GCS and use it within the context of a custom Dataflow operator. For convenience we wrapped a handful of utility methods in a separate module (this is important for avoiding problems with Dataflow dependency uploads) and import them into the local scope of the custom operator, not the global scope.</p>
<pre class="lang-py prettyprint-override"><code>class AddColumn(beam.DoFn):
    PRETRAINED_MODEL = 'gs://my-bucket/blah/roberta-model-files'

    def get_model_tokenizer_wrapper(self):
        import logging
        import shutil
        import tempfile
        import dataflow_util as util
        try:
            return self.model_tokenizer_wrapper
        except AttributeError:
            tmp_dir = tempfile.mkdtemp() + '/'
            util.download_tree(self.PRETRAINED_MODEL, tmp_dir)
            model, tokenizer = util.create_model_and_tokenizer(tmp_dir)
            model_tokenizer_wrapper = util.PretrainedPyTorchModelWrapper(
                model, tokenizer)
            shutil.rmtree(tmp_dir)
            self.model_tokenizer_wrapper = model_tokenizer_wrapper
            logging.info(
                'Successfully created PretrainedPyTorchModelWrapper')
            return self.model_tokenizer_wrapper

    def process(self, elem):
        model_tokenizer_wrapper = self.get_model_tokenizer_wrapper()
        # And now use that wrapper to process your elem however you need.
        # Note that when you read from BQ your elements are dictionaries
        # of the column names and values for each BQ row.
</code></pre>
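<p>The try/except trick in <code>get_model_tokenizer_wrapper</code> is a lazy-caching pattern: the expensive resource is built at most once per <code>DoFn</code> instance (i.e. once per worker) and reused for every later element. A minimal sketch of just that pattern, with no Beam or PyTorch involved (the class and resource here are invented for demonstration):</p>
<pre class="lang-py prettyprint-override"><code>class ExpensiveResourceHolder:
    """Sketch of the lazy attribute-caching pattern used in AddColumn."""

    def __init__(self):
        self.build_count = 0  # only to demonstrate the setup runs once

    def get_resource(self):
        try:
            # Fast path: the attribute exists after the first call.
            return self.resource
        except AttributeError:
            # Slow path: stand-in for expensive one-time setup
            # (e.g. downloading and loading a model).
            self.build_count += 1
            self.resource = {'model': 'loaded'}
            return self.resource


holder = ExpensiveResourceHolder()
first = holder.get_resource()
second = holder.get_resource()
# first and second are the same object; build_count stays at 1.
</code></pre>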
<p>The utility functions live in a standalone module in the codebase. In our project this is <code>dataflow_util/__init__.py</code> at the project root, but you don't have to structure it that way.</p>
<pre class="lang-py prettyprint-override"><code>from contextlib import closing
import logging
import os

import apache_beam as beam
import numpy as np
from pytorch_transformers import RobertaModel, RobertaTokenizer
import torch


class PretrainedPyTorchModelWrapper():
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer


def download_tree(gcs_dir, local_dir):
    gcs = beam.io.gcp.gcsio.GcsIO()
    assert gcs_dir.endswith('/')
    assert local_dir.endswith('/')
    for entry in gcs.list_prefix(gcs_dir):
        download_file(gcs, gcs_dir, local_dir, entry)


def download_file(gcs, gcs_dir, local_dir, entry):
    rel_path = entry[len(gcs_dir):]
    dest_path = local_dir + rel_path
    logging.info('Downloading %s', dest_path)
    # Create any intermediate local directories the GCS path implies.
    os.makedirs(os.path.dirname(dest_path), exist_ok=True)
    with closing(gcs.open(entry)) as f_read:
        with open(dest_path, 'wb') as f_write:
            # Download the file in chunks to avoid requiring large amounts
            # of RAM when downloading large files.
            while True:
                file_data_chunk = f_read.read(
                    beam.io.gcp.gcsio.DEFAULT_READ_BUFFER_SIZE)
                if file_data_chunk:
                    f_write.write(file_data_chunk)
                else:
                    break


def create_model_and_tokenizer(local_model_path_str):
    """Instantiate the pretrained transformer model and tokenizer.

    :param local_model_path_str: string path to the local directory
        containing the pretrained model files
    :return: model, tokenizer
    """
    model_class, tokenizer_class = (RobertaModel, RobertaTokenizer)
    # Load the pretrained tokenizer and model from the local directory.
    tokenizer = tokenizer_class.from_pretrained(local_model_path_str)
    model = model_class.from_pretrained(local_model_path_str)
    return model, tokenizer
</code></pre>
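<p>The chunked read loop in <code>download_file</code> is worth isolating: it bounds memory use by streaming fixed-size chunks rather than calling <code>read()</code> once on a potentially huge file. A self-contained sketch of the same loop using in-memory streams (no GCS involved; the function name is ours):</p>
<pre class="lang-py prettyprint-override"><code>import io
from contextlib import closing


def copy_in_chunks(f_read, f_write, chunk_size=16 * 1024 * 1024):
    # Same loop shape as download_file above: read fixed-size chunks
    # until the source is exhausted, so memory use stays bounded.
    while True:
        chunk = f_read.read(chunk_size)
        if chunk:
            f_write.write(chunk)
        else:
            break


src = io.BytesIO(b'x' * 100_000)
dst = io.BytesIO()
with closing(src):
    copy_in_chunks(src, dst, chunk_size=4096)
# dst now holds an exact copy of the 100,000 source bytes.
</code></pre>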
<p>And there you have it, folks! More details can be found here: <a href="https://beam.apache.org/documentation/sdks/python-pipeline-dependencies/" rel="nofollow noreferrer">https://beam.apache.org/documentation/sdks/python-pipeline-dependencies/</a></p>
<hr/>
<p><s>This whole chain of questions is moot, I've found, because Dataflow only lets you install source distribution packages on workers, which means you cannot actually install PyTorch.</s></p>
<p>When you supply a <code>requirements.txt</code> file, Dataflow installs it using the <code>--no-binary</code> flag, which prevents installation of wheel (.whl) packages and only allows source distributions (.tar.gz). I decided that trying to roll my own source distribution of PyTorch on Google Dataflow, where it is half C++, part CUDA, and part who knows what, was a fool's errand.</p>
<p>Thanks for the input along the way, everyone.</p>