将CSV解析为Pytorch张量

2024-05-08 21:47:19 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个CSV文件,除了标题行之外,所有数值都有。当尝试构建张量时,出现以下异常:

Traceback (most recent call last):
  File "pytorch.py", line 14, in <module>
    test_tensor = torch.tensor(test)
ValueError: could not determine the shape of object type 'DataFrame'

这是我的代码:

import torch
import dask.dataframe as dd

device = torch.device("cuda:0")

print("Loading CSV...")
test = dd.read_csv("test.csv", encoding = "UTF-8")
train = dd.read_csv("train.csv", encoding = "UTF-8")

print("Converting to Tensor...")
test_tensor = torch.tensor(test)
train_tensor = torch.tensor(train)

使用pandas而不是Dask进行CSV解析产生了相同的错误。我还试图在对torch.tensor(data)的调用中指定dtype=torch.float64,但再次出现了相同的错误。


Tags: 文件csvtestimportreaddevice错误train
2条回答

请先尝试将其转换为数组:

test_tensor = torch.Tensor(test.values)

我想你只是错过了.values

import torch
import pandas as pd

train = pd.read_csv('train.csv')
train_tensor = torch.tensor(train.values)

相关问题 更多 >