PyTorch で tar とか tar.gz とかでアーカイブしたデータセットを直接読みたい
なんでこんなことがしたいのか?
Colaboratory で何か機械学習モデルを trian したいとき、データセットは Google Drive に置いてマウントして読むというのは定番の手法 しかし、数千~数万ファイルあるデータを Google Drive にアップロードするのはめちゃくちゃ面倒だしすごい時間もかかる
tar とか tar.gz とかで単一のファイルにしてやればアップロードはすぐに終わる
容量として減らなくても単にファイル数が減るだけで所要時間が激減する
なので、これをいちいち展開せずに直接読んで ImageFolder 的に使えないか?という試み
Python には tarfile モジュールがあるのでこれを使えばなんとかなりそう とりあえずぱっと書いた
code:py
class TarGzImageFolderDataset(torch.utils.data.Dataset):
def __init__(self, file_path: Path, prefix: str=""):
self.prefix = prefix
self.file = tarfile.open(file_path, mode="r:gz")
self.file_paths = []
for info in self.file.getmembers():
if not info.isfile():
continue
self.file_paths.append(info.name)
def __del__(self):
self.file.close()
def __len__(self):
return len(self.file_paths)
def __getitem__(self, idx: int):
file = self.file.extractfile(self.file_pathsidx) assert file is not None
image = torchvision.io.decode_image(torch.frombuffer(file.read(), dtype=torch.uint8))
label = str(Path(self.file_pathsidx).parent.relative_to(Path(self.prefix))) return image, label
tar 内のパスが prefix/foobar/image.png とかだったときに foobar という文字列をラベルとして返してるのが良くないがとりあえず動いた