본문 바로가기
AI

ML07-4-MNIST-data

by 월곡동로봇팔 2019. 10. 24.

MNIST Dataset

MNIST Dataset은 실제로 손글씨를 data화 시킨 것으로 우리가 흔히 training set으로 사용한다.

아래 그림을 보면 784개의 픽셀 data로 이루어져있다.

이 dataset을 training 한 후, 우리가 예측하고자 하는 숫자를 넣으면 된다.

 

MNIST data

 

1. input_data.py
TensorFlow 샘플에 포함된 예제인데, mnist 데이터셋이 없을 경우 인터넷으로부터 다운로드한다.
추가로 DataSet 클래스 등의 정의가 이 안에 들어 있다.
mnist 데이터셋을 다루는 코드의 꼭대기에는 대부분 input_data.py를 import하는 코드가 들어 있다.

2. mnist 로딩

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)


one_hot은 label 데이터셋을 만들 때, label을 one-hot 방식으로 처리할 것인지를 가리킨다.
기본값은 False이기 때문에 사용할 경우 반드시 True를 전달해야 한다.
데이터가 one-hot 방식으로 넘어오면 처리하는 시점에 변환하지 않아도 되므로 편하다.

3. mnist 자료형
앞의 코드에서 반환한 mnist 변수의 자료형은 DataSets 클래스이다.
input_data.py 파일에 정의된 클래스로 mnist 샘플에서만 사용하는 임시 클래스이다.

DataSets 클래스는 구조체처럼 사용하기 위해 만든 클래스다.
train, validation, test의 3개 멤버 변수만 갖고 있고, 이들은 모두 DataSet 클래스이다.

 

DataSet 클래스는 말 그대로 한 데이터를 가지는 dataset이며 

DataSets 구조체는 Dataset을 element로 가지고 담고 있는 구조체이다.


cf) 구조체: 하나 이상의 변수를 묶어서 그룹화하는 사용자 정의 자료형.

from collections import namedtuple
MyStruct = namedtuple("MyStruct", "field1 field2 field3")
m = MyStruct(field1 = "foo", field2 = "bar", field3 = "baz")
print(m.field1)

namedtuple Class를 import 해서 (구조체이름, 구조체 변수) 타입으로 입력해준후, m이라는 객체를 생성.

객체 생성할 때, 옵션으로 변수에 값을 입력을 하면 실제로 객체에 변수가 들어간다. 

이와 같은 개념이고 이를 DataSet과 DataSets로 이용한다.

# base.py
Dataset = collections.namedtuple('Dataset', ['data', 'target'])
Datasets = collections.namedtuple('Datasets', ['train', 'validation', 'test'])

mnist에서는 Dataset이라는 구조체는 변수를 'data', 'target'으로 가지고 있다.

Datasets이라는 구조체는 변수를 'train', 'validation', 'test' 으로 가지고 있다.

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

 

 

4. test, validation, test

다운로드한 데이터는 55000개의 학습 데이터 (mnist.train), 10000개의 테스트 데이터 (mnist.test) 및 5000개의 검증 데이터 (mnist.validation) 세 부분으로 나누어져 있다.  이렇게 나눠진 것은 굉장히 중요하다.

학습하지 않은 별도의 데이터를 이용해서 학습한 결과가 실제로 일반적으로 적용되는지 검증하는 것이 기계 학습의 핵심이다.

앞에서 말했듯, 모든 MNIST 데이터 포인트들은 실제 데이터의 image, 데이터를 구분하는 label 두 부분으로 구성된다. 손으로 쓴 숫자와 그에 해당하는 라벨이다. 우리는 이미지를 "xs"라고 부르고 라벨을 "ys"라고 부를것이다.

 

+


이들의 자료형은 DataSet 클래스라고 얘기했다.
DataSet 클래스는 mnist 샘플에서 가져다 쓸 수 있도록 다양한 멤버들을 갖추고 있다.

이들의 자료형은 모두 numpy의 다차원 배열인 ndarray.
차원 변환과 transpose를 할 수 있고, 행렬 연산도 지원하기 때문에 당연히 numpy가 되는 것이 맞다.

train.images 데이터셋을 출력한 결과. 55,000x784 행렬. 1차원으로 하면 43,120,000개.
<class 'numpy.ndarray'> (55000, 784) 43120000

5. DataSet 클래스

  @property
  def images(self):
    return self._images

  @property
  def labels(self):
    return self._labels

  @property
  def num_examples(self):
    return self._num_examples

  @property
  def epochs_completed(self):
    return self._epochs_completed

  def next_batch(self, batch_size, fake_data=False, shuffle=True):
  ---

mnist.(  ).images : 이미지 데이터셋
mnist.(  ).labels : label 데이터셋
mnist.(  ).num_examples : 데이터 갯수
mnist.(  ).next_batch : 데이터셋으로부터 필요한 만큼의 데이터를 반환하는 함수

(  )안에는 train, validation, test가 들어갈 수 있다. 

 

train, validation, test도 read_data_sets 를 보면 Dataset으로 정의해놨기 때문이다.
print('갯수 :', len(mnist.train.images))
print('갯수 :", mnist.train.num_examples)

# train 데이터셋으로부터 데이터 100개 가져오기. (이미지, label) 튜플
train_images, train_labels = mnist.train.next_batch(100)

 

# mnist.py
# import base 이미 한 상황이다.
def read_data_sets(train_dir,
                   fake_data=False,
                   one_hot=False,
                   dtype=dtypes.float32,
                   reshape=True,
                   validation_size=5000,
                   seed=None,
                   source_url=DEFAULT_SOURCE_URL):
  # 여긴 fake일 경우 아무것도 return X
  if fake_data:
    
    def fake():
      return DataSet(
          [], [], fake_data=True, one_hot=one_hot, dtype=dtype, seed=seed)

    train = fake()
    validation = fake()
    test = fake()
    return base.Datasets(train=train, validation=validation, test=test)

  if not source_url:  # empty string check
    source_url = DEFAULT_SOURCE_URL
  #Train, Test 들은 이미 폴더안에있는 이름들 미리 정해둠
  TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
  TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
  TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
  TEST_LABELS = 't10k-labels-idx1-ubyte.gz'

  # local_file로 다운로드 받은 파일들을 각각 변수에 open함수로 읽은 후 집어넣음
  local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
                                   source_url + TRAIN_IMAGES)
  with gfile.Open(local_file, 'rb') as f:
    train_images = extract_images(f)

  local_file = base.maybe_download(TRAIN_LABELS, train_dir,
                                   source_url + TRAIN_LABELS)
  with gfile.Open(local_file, 'rb') as f:
    train_labels = extract_labels(f, one_hot=one_hot)

  local_file = base.maybe_download(TEST_IMAGES, train_dir,
                                   source_url + TEST_IMAGES)
  with gfile.Open(local_file, 'rb') as f:
    test_images = extract_images(f)

  local_file = base.maybe_download(TEST_LABELS, train_dir,
                                   source_url + TEST_LABELS)
  with gfile.Open(local_file, 'rb') as f:
    test_labels = extract_labels(f, one_hot=one_hot)

  if not 0 <= validation_size <= len(train_images):
    raise ValueError('Validation size should be between 0 and {}. Received: {}.'
                     .format(len(train_images), validation_size))
                     
  # validation_size는 DataSet의 클래스 위에서 속성으로 5000이라 정해줌
  # 응용하면 나중에 validation 구간을 미리 정할 수 있다.
  # validation_size = 10000으로 바꿀수도!!
  validation_images = train_images[:validation_size]
  validation_labels = train_labels[:validation_size]
  train_images = train_images[validation_size:]
  train_labels = train_labels[validation_size:]

  options = dict(dtype=dtype, reshape=reshape, seed=seed)

  train = DataSet(train_images, train_labels, **options)
  validation = DataSet(validation_images, validation_labels, **options)
  test = DataSet(test_images, test_labels, **options)

  return base.Datasets(train=train, validation=validation, test=test)

 

마지막에 train, validation, test 들은 모두 DataSet으로 정의되었고, 

이들이 return되어서 Datasets으로 들어갔다. 후에 mnist로 들어간다.


출처: https://pythonkim.tistory.com/46 [파이쿵]

 

mnist 데이터셋 정리

TensorFlow 샘플에 보면 mnist 데이터셋이 많이 등장한다. mnist를 잘 알면, 이후 코드를 보는데 도움이 될 것 같아서 정리해 놓는다. 1. input_data.py TensorFlow 샘플에 포함된 예제인데, mnist 데이터셋이 없..

pythonkim.tistory.com

댓글