2024. 6. 3. 03:30ㆍMOOC
학습 목표
- PyTorch의 Dataset 클래스를 Custom Dataset을 만들 수 있다.
- PyTorch Dataset을 입력으로 받아 PyTorch DataLoader를 만들 수 있다.
[docs]class MNIST(VisionDataset):
"""`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
Args:
root (str or ``pathlib.Path``): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte``
and ``MNIST/raw/t10k-images-idx3-ubyte`` exist.
train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
otherwise from ``t10k-images-idx3-ubyte``.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
- 이 부분은 클래스의 문서화 문자열(docstring)으로, 클래스와 초기화 매개변수에 대한 설명을 포함한다.
- root, train, download, transform, target_transform 매개변수를 통해 데이터셋의 루트 디렉토리, 훈련 또는 테스트 데이터셋 선택, 데이터셋 다운로드 여부, 입력 데이터 및 타겟 변환 함수 등을 설정할 수 있다
mirrors = [
"http://yann.lecun.com/exdb/mnist/",
"https://ossci-datasets.s3.amazonaws.com/mnist/",
]
resources = [
("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
]
training_file = "training.pt"
test_file = "test.pt"
classes = [
"0 - zero",
"1 - one",
"2 - two",
"3 - three",
"4 - four",
"5 - five",
"6 - six",
"7 - seven",
"8 - eight",
"9 - nine",
]
- mirrors: 데이터셋을 다운로드할 수 있는 URL 목록이다.
- resources: 데이터셋 파일과 해당 파일의 MD5 해시값 목록이다.
- training_file, test_file: 훈련 및 테스트 데이터 파일명이다.
- classes: 데이터셋의 클래스(숫자 0부터 9까지)를 문자열로 표현한 목록이다.
@property
def train_labels(self):
warnings.warn("train_labels has been renamed targets")
return self.targets
@property
def test_labels(self):
warnings.warn("test_labels has been renamed targets")
return self.targets
@property
def train_data(self):
warnings.warn("train_data has been renamed data")
return self.data
@property
def test_data(self):
warnings.warn("test_data has been renamed data")
return self.data
- train_labels, test_labels, train_data, test_data: 데이터셋의 레이블과 데이터를 접근하기 위한 속성이다.
- 각 접근자 메서드는 warnings.warn을 사용하여 해당 속성이 targets 또는 data로 이름이 변경되었음을 경고한다. 이후 변경된 속성을 반환한다.
주요 클래스
MNIST 클래스
MNIST 클래스는 MNIST 데이터셋을 다루기 위한 클래스로, 주요 기능은 다음과 같다:
__init__ 메서드
- 데이터셋의 루트 디렉토리, 학습용 데이터셋 여부, 변환 함수, 다운로드 여부 등의 인자를 받는다.
- 데이터셋이 이미 존재하면 로드하고, 존재하지 않으면 다운로드한다.
메서드 시그니처 및 매개변수
def __init__(
self,
root: Union[str, Path],
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
- root: 데이터셋의 루트 디렉토리 경로. 문자열(str) 또는 pathlib.Path 객체로 제공할 수 있다.
- train: 훈련용 데이터셋인지 테스트용 데이터셋인지를 나타내는 불리언 값. 기본값은 True로 훈련용 데이터셋을 의미한다.
- transform: 입력 데이터에 적용할 변환 함수 또는 호출 가능한 객체. 기본값은 None이다.
- target_transform: 타겟 데이터(레이블)에 적용할 변환 함수 또는 호출 가능한 객체. 기본값은 None이다.
- download: 데이터셋이 로컬에 없을 경우 다운로드할지 여부를 나타내는 불리언 값. 기본값은 False이다.
상위 클래스 초기화
super().__init__(root, transform=transform, target_transform=target_transform)
- 상위 클래스의 __init__ 메서드를 호출하여 기본 속성들을 초기화한다. 이때 root, transform, target_transform을 전달한다.
self.train = train # training set or test set
self.train 속성을 설정하여 훈련용 데이터셋인지 테스트용 데이터셋인지를 지정한다.
if self._check_legacy_exist():
self.data, self.targets = self._load_legacy_data()
return
- _check_legacy_exist 메서드를 호출하여 레거시 데이터가 존재하는지 확인한다.
- 레거시 데이터가 존재할 경우 _load_legacy_data 메서드를 호출하여 데이터를 로드하고, self.data와 self.targets 속성에 저장한 후 초기화 메서드를 종료한다(return).
if download:
self.download()
download 플래그가 True로 설정되어 있을 경우, download 메서드를 호출하여 데이터셋을 다운로드한다.
if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it")
- _check_exists 메서드를 호출하여 데이터셋이 로컬에 존재하는지 확인한다.
- 데이터셋이 존재하지 않을 경우 RuntimeError를 발생시켜 에러 메시지를 출력한다.
self.data, self.targets = self._load_data()
_load_data 메서드를 호출하여 데이터를 로드하고, self.data와 self.targets 속성에 저장한다.
전체 초기화 과정 요약
- 상위 클래스 초기화: super().__init__(...)를 호출하여 기본 속성을 초기화한다.
- 속성 설정: self.train을 설정하여 훈련용 데이터셋인지 테스트용 데이터셋인지를 지정한다.
- 레거시 데이터 체크: _check_legacy_exist()를 호출하여 레거시 데이터가 존재할 경우, 레거시 데이터를 로드하고 초기화를 종료한다.
- 데이터셋 다운로드: download 플래그가 True일 경우, download() 메서드를 호출하여 데이터셋을 다운로드한다.
- 데이터셋 존재 여부 체크: _check_exists()를 호출하여 데이터셋이 존재하지 않을 경우, 에러를 발생시킨다.
- 데이터 로드: _load_data()를 호출하여 데이터를 로드하고, self.data와 self.targets에 저장한다.
이 코드에서는 MNIST 데이터셋을 로드한다. MNIST 데이터셋은 손으로 쓴 숫자(0-9)의 이미지 데이터셋으로, 머신러닝과 딥러닝 모델을 훈련하고 평가하는 데 자주 사용된다.
MNIST 데이터셋 구성
- 이미지 데이터:
- 각 이미지는 28x28 픽셀의 흑백 이미지이다.
- 훈련용 데이터셋에는 60,000개의 이미지가 있고, 테스트용 데이터셋에는 10,000개의 이미지가 있다.
- 레이블 데이터:
- 각 이미지는 0부터 9까지의 숫자 중 하나에 해당하는 레이블이 있다.
- 레이블은 이미지에 표시된 숫자를 나타낸다.
- 훈련용 데이터셋 (Training Dataset):
- train=True로 설정된 경우 훈련용 데이터셋을 로드한다.
- 훈련용 이미지 파일: train-images-idx3-ubyte.gz
- 훈련용 레이블 파일: train-labels-idx1-ubyte.gz
- 이 데이터셋은 모델을 훈련시키기 위해 사용된다.
- 테스트용 데이터셋 (Test Dataset):
- train=False로 설정된 경우 테스트용 데이터셋을 로드한다.
- 테스트용 이미지 파일: t10k-images-idx3-ubyte.gz
- 테스트용 레이블 파일: t10k-labels-idx1-ubyte.gz
- 이 데이터셋은 훈련된 모델을 평가하기 위해 사용된다.
레거시 데이터: 이전 버전의 코드나 시스템에서 사용되던 데이터 형식.
def _check_legacy_exist(self):
processed_folder_exists = os.path.exists(self.processed_folder)
if not processed_folder_exists:
return False
return all(
check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
)
def _load_legacy_data(self):
# This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
# directly.
data_file = self.training_file if self.train else self.test_file
return torch.load(os.path.join(self.processed_folder, data_file), weights_only=True)
def _load_data(self):
image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
data = read_image_file(os.path.join(self.raw_folder, image_file))
label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
targets = read_label_file(os.path.join(self.raw_folder, label_file))
return data, targets
[docs] def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], int(self.targets[index])
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode="L")
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self) -> int:
return len(self.data)
@property
def raw_folder(self) -> str:
return os.path.join(self.root, self.__class__.__name__, "raw")
@property
def processed_folder(self) -> str:
return os.path.join(self.root, self.__class__.__name__, "processed")
@property
def class_to_idx(self) -> Dict[str, int]:
return {_class: i for i, _class in enumerate(self.classes)}
def _check_exists(self) -> bool:
return all(
check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
for url, _ in self.resources
)
def download(self) -> None:
"""Download the MNIST data if it doesn't exist already."""
if self._check_exists():
return
os.makedirs(self.raw_folder, exist_ok=True)
# download files
for filename, md5 in self.resources:
for mirror in self.mirrors:
url = f"{mirror}{filename}"
try:
print(f"Downloading {url}")
download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
except URLError as error:
print(f"Failed to download (trying next):\n{error}")
continue
finally:
print()
break
else:
raise RuntimeError(f"Error downloading {filename}")
def extra_repr(self) -> str:
split = "Train" if self.train is True else "Test"
return f"Split: {split}"
- 레거시 데이터 존재 여부 확인:
- _check_legacy_exist() 메서드를 사용하여 레거시 데이터가 존재하는지 확인한다. 만약 존재하면 _load_legacy_data() 메서드를 통해 데이터를 로드하고, 초기화를 종료한다.
- 데이터 다운로드 여부 확인:
- download 플래그가 True로 설정된 경우, download() 메서드를 호출하여 필요한 데이터를 인터넷에서 다운로드한다. 다운로드 후 파일의 무결성을 확인한다.
- 데이터셋 존재 여부 확인:
- _check_exists() 메서드를 호출하여 데이터셋이 이미 존재하는지 확인한다. 데이터셋이 존재하지 않으면 RuntimeError를 발생시킨다.
- 데이터 로드:
- _load_data() 메서드를 호출하여 훈련용 또는 테스트용 데이터를 로드한다. 이미지 파일과 레이블 파일을 읽어와 self.data와 self.targets에 저장한다.
- 데이터 항목 반환:
- __getitem__() 메서드를 통해 특정 인덱스의 이미지와 레이블을 반환한다. 필요시 변환을 적용하여 반환한다.
- 데이터셋 크기 반환:
- __len__() 메서드를 통해 데이터셋의 총 항목 수를 반환한다.
- 폴더 경로 제공:
- raw_folder와 processed_folder 속성을 통해 원본 데이터와 처리된 데이터의 폴더 경로를 제공한다.
- 클래스 인덱스 매핑:
- class_to_idx 속성을 통해 클래스 이름을 인덱스로 매핑하는 딕셔너리를 제공한다.
- 추가 정보 반환:
- extra_repr() 메서드를 통해 데이터셋의 상태를 나타내는 문자열을 반환한다.
class QMNIST(MNIST):
"""`QMNIST <https://github.com/facebookresearch/qmnist>`_ Dataset.
Args:
root (str or ``pathlib.Path``): Root directory of dataset whose ``raw``
subdir contains binary files of the datasets.
what (string,optional): Can be 'train', 'test', 'test10k',
'test50k', or 'nist' for respectively the mnist compatible
training set, the 60k qmnist testing set, the 10k qmnist
examples that match the mnist testing set, the 50k
remaining qmnist testing examples, or all the nist
digits. The default is to select 'train' or 'test'
according to the compatibility argument 'train'.
compat (bool,optional): A boolean that says whether the target
for each example is class number (for compatibility with
the MNIST dataloader) or a torch vector containing the
full qmnist information. Default=True.
download (bool, optional): If True, downloads the dataset from
the internet and puts it in root directory. If dataset is
already downloaded, it is not downloaded again.
transform (callable, optional): A function/transform that
takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform
that takes in the target and transforms it.
train (bool,optional,compatibility): When argument 'what' is
not specified, this boolean decides whether to load the
training set or the testing set. Default: True.
"""
subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"}
resources: Dict[str, List[Tuple[str, str]]] = { # type: ignore[assignment]
"train": [
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz",
"ed72d4157d28c017586c42bc6afe6370",
),
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz",
"0058f8dd561b90ffdd0f734c6a30e5e4",
),
],
"test": [
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz",
"1394631089c404de565df7b7aeaf9412",
),
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz",
"5b5b05890a5e13444e108efe57b788aa",
),
],
"nist": [
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz",
"7f124b3b8ab81486c9d8c2749c17f834",
),
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz",
"5ed0e788978e45d4a8bd4b7caec3d79d",
),
],
}
classes = [
"0 - zero",
"1 - one",
"2 - two",
"3 - three",
"4 - four",
"5 - five",
"6 - six",
"7 - seven",
"8 - eight",
"9 - nine",
]
클래스 설명
클래스 선언: QMNIST 클래스는 MNIST 클래스를 상속받아 정의된다.
클래스 속성과 메서드
클래스 문서화 문자열
- 클래스에 대한 설명과 초기화 매개변수에 대한 설명이 포함되어 있다.
- 주요 매개변수:
- root: 데이터셋이 저장될 루트 디렉토리 경로.
- what: 어떤 데이터셋을 로드할지 결정하는 문자열로, 'train', 'test', 'test10k', 'test50k', 'nist' 중 하나를 선택할 수 있다.
- compat: MNIST와의 호환성을 위해 타겟이 클래스 번호인지, QMNIST의 전체 정보를 포함하는지 결정하는 불리언 값.
- download: 데이터셋을 인터넷에서 다운로드할지 여부를 결정하는 불리언 값.
- transform: PIL 이미지를 받아 변환된 버전을 반환하는 함수 또는 호출 가능한 객체.
- target_transform: 타겟을 받아 변환된 버전을 반환하는 함수 또는 호출 가능한 객체.
- train: what 매개변수가 지정되지 않은 경우, 훈련용 데이터셋을 로드할지 테스트용 데이터셋을 로드할지 결정하는 불리언 값.
subsets 딕셔너리
- what 매개변수의 값과 실제 데이터셋 파일 간의 매핑을 정의한다.
- 예를 들어, test10k와 test50k는 모두 test 데이터 파일을 사용한다.
resources 딕셔너리
- 각 데이터셋(train, test, nist)에 대해 다운로드할 파일의 URL과 해당 파일의 MD5 해시 값을 포함한다.
- 이 딕셔너리는 데이터셋을 다운로드할 때 사용된다.
classes 리스트
- 숫자 0부터 9까지의 클래스 이름을 문자열로 정의한 리스트이다.
- 각 클래스는 "숫자 - 영문" 형식으로 정의된다.
요약
- 이 클래스는 QMNIST 데이터셋을 처리하기 위해 설계되었다.
- subsets 딕셔너리를 통해 what 매개변수의 값과 실제 데이터셋 파일을 매핑한다.
- resources 딕셔너리를 통해 각 데이터셋에 대한 다운로드 URL과 파일의 MD5 해시 값을 정의한다.
- classes 리스트를 통해 숫자 0부터 9까지의 클래스 이름을 정의한다.
def __init__(
self, root: Union[str, Path], what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any
) -> None:
if what is None:
what = "train" if train else "test"
self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
self.compat = compat
self.data_file = what + ".pt"
self.training_file = self.data_file
self.test_file = self.data_file
super().__init__(root, train, **kwargs)
@property
def images_file(self) -> str:
(url, _), _ = self.resources[self.subsets[self.what]]
return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
@property
def labels_file(self) -> str:
_, (url, _) = self.resources[self.subsets[self.what]]
return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
def _check_exists(self) -> bool:
return all(check_integrity(file) for file in (self.images_file, self.labels_file))
def _load_data(self):
data = read_sn3_pascalvincent_tensor(self.images_file)
if data.dtype != torch.uint8:
raise TypeError(f"data should be of dtype torch.uint8 instead of {data.dtype}")
if data.ndimension() != 3:
raise ValueError("data should have 3 dimensions instead of {data.ndimension()}")
targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
if targets.ndimension() != 2:
raise ValueError(f"targets should have 2 dimensions instead of {targets.ndimension()}")
if self.what == "test10k":
data = data[0:10000, :, :].clone()
targets = targets[0:10000, :].clone()
elif self.what == "test50k":
data = data[10000:, :, :].clone()
targets = targets[10000:, :].clone()
return data, targets
def download(self) -> None:
"""Download the QMNIST data if it doesn't exist already.
Note that we only download what has been asked for (argument 'what').
"""
if self._check_exists():
return
os.makedirs(self.raw_folder, exist_ok=True)
split = self.resources[self.subsets[self.what]]
for url, md5 in split:
download_and_extract_archive(url, self.raw_folder, md5=md5)
[docs] def __getitem__(self, index: int) -> Tuple[Any, Any]:
# redefined to handle the compat flag
img, target = self.data[index], self.targets[index]
img = Image.fromarray(img.numpy(), mode="L")
if self.transform is not None:
img = self.transform(img)
if self.compat:
target = int(target[0])
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def extra_repr(self) -> str:
return f"Split: {self.what}"
def get_int(b: bytes) -> int:
return int(codecs.encode(b, "hex"), 16)
SN3_PASCALVINCENT_TYPEMAP = {
8: torch.uint8,
9: torch.int8,
11: torch.int16,
12: torch.int32,
13: torch.float32,
14: torch.float64,
}
def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
"""Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
Argument may be a filename, compressed filename, or file object.
"""
# read
with open(path, "rb") as f:
data = f.read()
# parse
if sys.byteorder == "little":
magic = get_int(data[0:4])
nd = magic % 256
ty = magic // 256
else:
nd = get_int(data[0:1])
ty = get_int(data[1:2]) + get_int(data[2:3]) * 256 + get_int(data[3:4]) * 256 * 256
assert 1 <= nd <= 3
assert 8 <= ty <= 14
torch_type = SN3_PASCALVINCENT_TYPEMAP[ty]
s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]
if sys.byteorder == "big":
for i in range(len(s)):
s[i] = int.from_bytes(s[i].to_bytes(4, byteorder="little"), byteorder="big", signed=False)
parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1)))
# The MNIST format uses the big endian byte order, while `torch.frombuffer` uses whatever the system uses. In case
# that is little endian and the dtype has more than one byte, we need to flip them.
if sys.byteorder == "little" and parsed.element_size() > 1:
parsed = _flip_byte_order(parsed)
assert parsed.shape[0] == np.prod(s) or not strict
return parsed.view(*s)
def read_label_file(path: str) -> torch.Tensor:
x = read_sn3_pascalvincent_tensor(path, strict=False)
if x.dtype != torch.uint8:
raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
if x.ndimension() != 1:
raise ValueError(f"x should have 1 dimension instead of {x.ndimension()}")
return x.long()
def read_image_file(path: str) -> torch.Tensor:
x = read_sn3_pascalvincent_tensor(path, strict=False)
if x.dtype != torch.uint8:
raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
if x.ndimension() != 3:
raise ValueError(f"x should have 3 dimension instead of {x.ndimension()}")
return x
클래스 초기화 및 속성 설정
- 클래스 초기화 (__init__ 메서드):
- root, what, compat, train 등의 매개변수를 받는다.
- what이 지정되지 않은 경우, train 매개변수에 따라 기본값을 'train' 또는 'test'로 설정한다.
- self.what 속성을 설정하고, 이를 검증한다.
- compat 속성을 설정하여 호환성을 관리한다.
- 데이터 파일 이름을 결정하고, 이를 training_file과 test_file 속성에 저장한다.
- 부모 클래스의 초기화 메서드를 호출하여 기본 설정을 마친다.
- 파일 경로 속성 (images_file, labels_file 속성):
- images_file: 데이터셋의 이미지 파일 경로를 반환한다.
- labels_file: 데이터셋의 레이블 파일 경로를 반환한다.
데이터 체크 및 로드
- 데이터 존재 여부 확인 (_check_exists 메서드):
- 이미지 파일과 레이블 파일이 존재하고 무결성이 확인되면 True를 반환한다.
- 데이터 로드 (_load_data 메서드):
- 이미지 파일과 레이블 파일을 읽어 텐서로 변환한다.
- 데이터와 레이블의 형식과 차원을 검증한다.
- 특정 서브셋(test10k, test50k)의 경우, 데이터와 레이블을 부분적으로 복제하여 반환한다.
데이터 다운로드
- 데이터 다운로드 (download 메서드):
- 데이터셋이 존재하지 않는 경우, 필요한 파일들을 지정된 URL에서 다운로드하고 압축을 해제한다.
데이터 접근
- 데이터 항목 반환 (__getitem__ 메서드):
- 특정 인덱스의 이미지와 타겟(레이블)을 반환한다.
- 필요시 변환을 적용하고, compat 속성에 따라 타겟을 처리한다.
- 추가 정보 반환 (extra_repr 메서드):
- 데이터셋의 상태를 나타내는 문자열을 반환한다.
데이터 읽기 유틸리티
- SN3 파일 형식 읽기 (read_sn3_pascalvincent_tensor 메서드):
- Pascal Vincent 형식의 SN3 파일을 읽어 텐서로 변환한다.
- 파일의 매직 넘버, 차원, 데이터 타입 등을 확인하고, 데이터를 읽어 적절한 텐서로 변환한다.
- 레이블 파일 읽기 (read_label_file 메서드):
- 레이블 파일을 읽어 텐서로 변환하고, 형식과 차원을 검증한다.
- 이미지 파일 읽기 (read_image_file 메서드):
- 이미지 파일을 읽어 텐서로 변환하고, 형식과 차원을 검증한다.
Datasets & DataLoaders
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
- datasets.FashionMNIST: FashionMNIST 데이터셋을 로드하는 클래스.
- root="data": 데이터가 저장될 경로를 지정.
- train=True: 학습 데이터를 로드할지 여부를 지정.
- download=True: 데이터가 없을 경우 인터넷에서 다운로드하도록 지정.
- transform=ToTensor(): 이미지를 텐서로 변환하는 변환기를 지정.
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
- train=False: 테스트 데이터를 로드할지 여부를 지정.
- 나머지 매개변수는 학습 데이터 로드와 동일하게 지정.
이 코드는 PyTorch와 torchvision을 사용하여 FashionMNIST 데이터셋의 훈련 및 테스트 데이터를 다운로드하고, 각 이미지를 텐서 형태로 변환하여 로드하는 과정을 보여준다.
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
sample_idx = torch.randint(len(training_data), size=(1,)).item()
img, label = training_data[sample_idx]
figure.add_subplot(rows, cols, i)
plt.title(labels_map[label])
plt.axis("off")
plt.imshow(img.squeeze(), cmap="gray")
plt.show()
- 레이블 맵핑 설정:
- FashionMNIST 데이터셋의 레이블을 의류 이름으로 변환하기 위해 labels_map 딕셔너리를 생성한다. 이 딕셔너리는 각 레이블(0~9)을 대응되는 의류 이름으로 매핑한다.
- Figure 객체 생성:
- 이미지를 시각화할 figure 객체를 생성한다. 이 객체의 크기는 8x8 인치로 설정한다.
- 그리드 크기 설정:
- 이미지가 3x3 그리드에 표시되도록 열(cols)과 행(rows) 수를 각각 3으로 설정한다.
- 랜덤 샘플 선택 및 시각화:
- 3x3 그리드에 이미지를 채우기 위해 9번 반복하는 for 루프를 실행한다.
- 각 반복마다 training_data에서 랜덤으로 인덱스를 선택한다.
- 선택한 인덱스의 샘플 이미지와 레이블을 가져온다.
- figure 객체에 새로운 서브플롯을 추가하여 그리드에 이미지를 추가할 위치를 지정한다.
- 레이블을 labels_map을 통해 의류 이름으로 변환하여 서브플롯의 제목으로 설정한다.
- 축을 숨겨 이미지만 보이도록 한다.
- 이미지를 그레이스케일로 표시한다.
- 모든 서브플롯이 추가되면 plt.show()를 호출하여 완성된 그리드를 화면에 표시한다.
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
- 필요한 라이브러리 임포트:
- os 라이브러리는 파일 경로를 조작하기 위해 사용되고, pandas는 CSV 파일을 읽고 데이터프레임으로 변환하는 데 사용된다. torchvision.io.read_image는 이미지 파일을 읽어 텐서로 변환하는 데 사용된다.
- 사용자 정의 데이터셋 클래스 정의:
- torch.utils.data.Dataset을 상속받아 CustomImageDataset이라는 새로운 클래스를 정의한다. 이 클래스는 이미지와 레이블을 포함하는 데이터셋을 표현한다.
- 클래스 초기화 메서드:
- 클래스가 초기화될 때 실행되는 메서드로, 이미지와 레이블 정보를 담고 있는 CSV 파일 경로와 이미지가 저장된 디렉토리 경로를 입력받는다. 또한, 이미지와 레이블에 적용할 수 있는 선택적 변환 함수들도 입력받는다. 이 메서드는 CSV 파일을 읽어 self.img_labels에 저장하고, 이미지 디렉토리 경로와 변환 함수들을 각각 클래스 속성으로 저장한다.
- 데이터셋의 길이를 반환하는 메서드:
- 데이터셋의 총 샘플 수를 반환하는 메서드로, self.img_labels의 길이를 반환하여 데이터셋의 크기를 알려준다.
- 특정 인덱스의 데이터를 반환하는 메서드:
- 주어진 인덱스에 해당하는 이미지와 레이블을 반환하는 메서드이다. 먼저, 이미지 파일의 경로를 생성하고, read_image를 사용하여 이미지를 읽어온다. 그런 다음, 데이터프레임에서 해당 인덱스의 레이블을 가져온다. 이미지와 레이블에 변환 함수가 지정되어 있다면, 각각 변환을 적용한다. 마지막으로, 변환된 이미지와 레이블을 반환한다.
__init__
__init__ 함수는 Dataset 객체를 인스턴스화할 때 한 번 실행된다. 우리는 이 함수에서 이미지를 포함하는 디렉토리, 주석 파일, 그리고 두 가지 변환(transform)을 초기화한다 (다음 섹션에서 더 자세히 다룰 것이다).
The labels.csv file looks like:
tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
이 초기화 메서드는 CustomImageDataset 클래스가 생성될 때 다음과 같은 작업을 수행한다:
- CSV 파일을 읽어 이미지 파일 이름과 레이블을 포함하는 데이터프레임을 생성하고, 이를 클래스 인스턴스의 속성으로 저장한다.
- 이미지 파일들이 저장된 디렉토리 경로를 클래스 인스턴스의 속성으로 저장한다.
- 선택적으로 이미지를 변환할 수 있는 변환 함수와 레이블을 변환할 수 있는 변환 함수를 클래스 인스턴스의 속성으로 저장한다.
def __len__(self):
return len(self.img_labels)
데이터셋의 크기를 반환
__getitem__
__getitem__ 함수는 주어진 인덱스 idx에 해당하는 데이터셋의 샘플을 로드하고 반환한다. 인덱스를 기반으로 이미지의 디스크 상 위치를 식별하고, 이를 read_image를 사용하여 텐서로 변환한다. 또한, self.img_labels에 있는 CSV 데이터에서 해당하는 레이블을 가져오며, 적용 가능한 변환 함수들을 호출한 후 텐서 이미지와 해당 레이블을 튜플로 반환한다.
def __getitem__(self,idx):
img_path=os.path.join(self.img_dir,self.img_labels.iloc[idx,0])
image=read_image(img_path)
label=self.ilmg_labels.iloc[idx,1]
if self.transform:
iamge=self.transform
if self.target_transform:
label=self.target_transform(label)
return image, label
- 이미지 파일 경로 생성:
- 주어진 인덱스에 해당하는 이미지 파일 이름을 데이터프레임에서 가져와, 이미지 디렉토리 경로와 결합하여 전체 이미지 파일 경로를 생성한다.
- 이미지 읽기:
- 생성된 이미지 파일 경로를 사용하여 이미지를 읽고, 이를 텐서 형식으로 변환한다.
- 레이블 읽기:
- 주어진 인덱스에 해당하는 레이블을 데이터프레임에서 가져온다.
- 이미지 변환 적용:
- transform이 정의되어 있으면, 이미지를 변환한다.
- 레이블 변환 적용:
- target_transform이 정의되어 있으면, 레이블을 변환한다.
- 이미지와 레이블 반환:
- 변환된 이미지와 레이블을 튜플로 반환하여, 데이터셋의 해당 샘플을 가져온다.
데이터 준비하기: DataLoader를 사용한 훈련
Dataset은 데이터셋의 특징(features)과 레이블(labels)을 한 번에 하나씩 가져온다. 모델을 훈련할 때는 일반적으로 샘플을 "미니배치(minibatches)"로 전달하고, 매 에포크마다 데이터를 다시 섞어 모델의 과적합을 줄이며, 데이터 검색 속도를 높이기 위해 파이썬의 멀티프로세싱을 사용하고자 한다.
DataLoader는 이러한 복잡성을 쉽게 처리할 수 있는 API로 추상화한 반복자(iterable)이다.
from torch.utils.data import DataLoader
train_dataloader=DataLoader(training_data,batch_size=64,shuffle=True)
test_dataloader=DataLoader(test_data,batch_size=64,shuffle=True)
이 코드는 두 개의 DataLoader 객체를 생성한다. 하나는 학습용 데이터(training_data)를, 다른 하나는 테스트용 데이터(test_data)를 로드하는 역할
DataLoader를 통해 데이터 반복하기
우리는 데이터셋을 DataLoader에 로드했고, 필요에 따라 데이터셋을 반복(iterate)할 수 있다. 아래의 각 반복(iteration)은 train_features와 train_labels의 배치를 반환하며, 각각 배치 크기(batch_size)=64의 특징(features)과 레이블(labels)을 포함하고 있다. shuffle=True를 지정했기 때문에, 모든 배치를 반복한 후에는 데이터가 섞인다. 데이터 로딩 순서에 대한 더 세밀한 제어를 위해서는 Samplers를 참고하면 된다.
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
DataLoader를 사용하여 배치 단위로 데이터를 로드하고, 첫 번째 배치에서 첫 번째 이미지를 선택하여 시각화
'MOOC' 카테고리의 다른 글
[프로젝트로 배우는 데이터사이언스] pima_classification_baseline (1) | 2024.06.27 |
---|---|
PyTorch 프로젝트 구조 이해하기 (0) | 2024.06.06 |
[파이토치로 만드는 딥러닝 이론3] 모델 저장하기 (1) | 2024.06.03 |
[파이토치로 만드는 딥러닝 이론1] nn.Module & nn.Parameter & Backward (0) | 2024.05.31 |
[딥러닝] 흉부 엑스레이 이미지 폐렴(PNEUMONIA) 분류 실습 (0) | 2024.05.09 |