import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import json
# Load the pre-split image/label arrays (paths are relative to the CWD).
train_images = np.load('../dataset/train_image.npy')
train_labels = np.load('../dataset/train_label_3.npy')
test_images = np.load('../dataset/test_image.npy')
test_labels = np.load('../dataset/test_label_3.npy')

# Collapse the per-class axis (axis 1) to an integer class index per sample.
# NOTE(review): presumably the label files are one-hot encoded — confirm.
train_labels = np.argmax(train_labels, axis=1)
test_labels = np.argmax(test_labels, axis=1)

# NOTE(review): assumes the images are floats scaled to [0, 1] — confirm.
# Round rather than truncate, so float error such as 0.99999994 * 255 maps
# to 255 instead of 254. Clip so out-of-range values cannot wrap around (or
# hit undefined float->uint8 casts) when converted to uint8.
train_images = np.clip(np.rint(train_images * 255), 0, 255).astype(np.uint8)
test_images = np.clip(np.rint(test_images * 255), 0, 255).astype(np.uint8)
class NumpyToPIL:
    """Callable transform that converts a NumPy image array to a PIL image.

    Intended as the first stage of a ``transforms.Compose`` pipeline, so
    that later PIL-based transforms (resize, flip, ...) can operate on it.
    """

    def __call__(self, sample):
        pil_image = Image.fromarray(sample)
        return pil_image
class CustomImageDataset(Dataset):
    """Map-style dataset pairing in-memory images with their labels.

    Args:
        images: Indexable sequence of images (e.g. a NumPy array). Its
            length must match ``labels``.
        labels: Indexable sequence of labels, one per image.
        transform: Optional callable applied to each image on access.
            Labels are returned untransformed.

    Raises:
        ValueError: If ``images`` and ``labels`` differ in length.
    """

    def __init__(self, images, labels, transform=None):
        # Fail fast. Otherwise a mismatch surfaces as an IndexError mid-epoch
        # (too few labels) or silently ignores trailing labels (too many).
        if len(images) != len(labels):
            raise ValueError(
                "images and labels must have the same length "
                f"(got {len(images)} and {len(labels)})"
            )
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample *idx*, transforming the image."""
        image = self.images[idx]
        label = self.labels[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label
# Batch size for both loaders. It was referenced here but never defined in
# this script, which raised NameError. 32 is a conventional starting value;
# tune as needed.
BATCH_SIZE = 32

# Commonly used ImageNet per-channel normalization statistics.
# NOTE(review): three values imply 3-channel (RGB) images — confirm the
# stored arrays are (H, W, 3).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: numpy -> PIL, resize to 224x224, random horizontal
# flip for augmentation, then tensor conversion and normalization.
transform_train = transforms.Compose([
    NumpyToPIL(),
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Evaluation pipeline: identical, but without random augmentation.
transform_test = transforms.Compose([
    NumpyToPIL(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

dataset_train = CustomImageDataset(train_images, train_labels, transform=transform_train)
dataset_test = CustomImageDataset(test_images, test_labels, transform=transform_test)

# NOTE(review): with num_workers > 0 under the "spawn" start method
# (Windows/macOS), worker processes re-import this module. Top-level script
# code may need an `if __name__ == "__main__":` guard.
train_loader = DataLoader(dataset_train, batch_size=BATCH_SIZE, num_workers=8, shuffle=True, drop_last=True)
test_loader = DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
# Flatten to 1-D (a no-op when the labels are already 1-D).
train_labels = train_labels.ravel()
test_labels = test_labels.ravel()


def _class_to_idx(labels):
    """Map each distinct label, keyed by its string form, to its int value.

    Classes are sorted explicitly so the mapping and the files written from
    it have a stable key order, rather than depending on set iteration order.
    """
    return {str(c): c for c in sorted(set(labels.tolist()))}


def _write_class_map(class_to_idx, stem):
    """Write *class_to_idx* to ``<stem>.txt`` (Python repr) and ``<stem>.json``."""
    with open(stem + '.txt', 'w', encoding='utf-8') as file:
        file.write(str(class_to_idx))
    with open(stem + '.json', 'w', encoding='utf-8') as file:
        json.dump(class_to_idx, file)


train_class_to_idx = _class_to_idx(train_labels)
test_class_to_idx = _class_to_idx(test_labels)
_write_class_map(train_class_to_idx, 'train_class')
_write_class_map(test_class_to_idx, 'test_class')