我正在使用下面的代码读取 CIFAR-10 数据集,并希望找到一种方法把图像转换为 HSV。我尝试把 data 和 labels 两个列表放到函数外面,但得到了这个错误:UnboundLocalError: local variable 'data' referenced before assignment
怎样才能取出这些列表,以便先把图像转换为 HSV,再绘制整个数据集的直方图?
import pickle
import numpy as np
from os.path import join
from os import listdir
import matplotlib.pyplot as plt
from tqdm import tqdm
import struct as st
class DataReader:
    """Read image datasets from a directory and plot samples from them.

    The dataset layout is selected by ``type``: ``'cifar-100'`` (the
    default), ``'cifar-10'`` or ``'mnist'``.  The loaders rely on the
    module-level ``unpickle`` helper to read the pickled batch files.
    """

    def __init__(self, root_dir, type='cifar-100'):
        """Remember where the dataset lives and which layout it uses.

        Args:
            root_dir: Directory that contains the dataset files.
            type: One of ``'cifar-100'``, ``'cifar-10'`` or ``'mnist'``.
        """
        self.root_dir = root_dir
        self.type = type

    def get_dict_from_pickle(self):
        """Load the ``train`` and ``test`` pickles found in ``root_dir``.

        The results are stored on ``self.train_dict`` and ``self.test_dict``.
        """
        self.train_dict = unpickle(join(self.root_dir, 'train'))
        self.test_dict = unpickle(join(self.root_dir, 'test'))

    def get_train_data(self):
        """Return the training split as ``(data, labels, sub_labels)``.

        * ``'cifar-100'``: ``(data, coarse_labels, fine_labels)``.
        * ``'cifar-10'``: ``(data, labels, None)``, gathered from every file
          in ``root_dir`` whose name starts with ``data_``.
        * ``'mnist'``: whatever ``self.read_mnist()`` returns.

        Any other ``type`` yields ``None``.
        """
        if self.type == 'cifar-100':
            self.get_dict_from_pickle()
            data = np.array(self.train_dict[b'data'])
            lbls_sub = np.array(self.train_dict[b'fine_labels'])
            lbls_class = np.array(self.train_dict[b'coarse_labels'])
            return data, lbls_class, lbls_sub
        elif self.type == 'cifar-10':
            # These accumulators must exist before the loop: without them the
            # first ``extend`` fails because the names were never bound.
            data = []
            labels = []
            print("Reading")
            # Sorted so the batches are concatenated in a reproducible order;
            # ``listdir`` itself returns entries in arbitrary order.
            for file_ in tqdm(sorted(listdir(self.root_dir))):
                if file_.split('_')[0] == 'data':
                    batch = unpickle(join(self.root_dir, file_))
                    data.extend(batch[b'data'])
                    labels.extend(batch[b'labels'])
            return np.array(data), np.array(labels), None
        elif self.type == 'mnist':
            # NOTE(review): ``read_mnist`` is not defined in the visible part
            # of this class -- confirm it is provided elsewhere.
            return self.read_mnist()

    def get_test_data(self):
        """Return the test split as ``(data, labels, sub_labels)``.

        * ``'cifar-100'``: ``(data, coarse_labels, fine_labels)``.
        * ``'cifar-10'``: ``(data, labels, None)``, gathered from every file
          in ``root_dir`` whose name starts with ``test_``.

        Any other ``type`` yields ``None``.
        """
        if self.type == 'cifar-100':
            self.get_dict_from_pickle()
            data = np.array(self.test_dict[b'data'])
            lbls_sub = np.array(self.test_dict[b'fine_labels'])
            lbls_class = np.array(self.test_dict[b'coarse_labels'])
            return data, lbls_class, lbls_sub
        elif self.type == 'cifar-10':
            # Start from an empty (0, 3072) array so the result keeps that
            # width even when no test file is found.
            data = np.empty(shape=(0, 3072))
            labels = []
            for file_ in sorted(listdir(self.root_dir)):
                if file_.split('_')[0] == 'test':
                    batch = unpickle(join(self.root_dir, file_))
                    data = np.vstack((data, batch[b'data']))
                    # ``extend`` keeps the labels flat, one entry per image,
                    # matching what ``get_train_data`` returns.
                    labels.extend(batch[b'labels'])
            return np.array(data), np.array(labels), None

    def reshape_to_plot(self, data):
        """Turn flat image rows into ``uint8`` arrays that ``imshow`` accepts.

        For ``'mnist'`` each row becomes a 28x28 image.  Otherwise each row
        is read as three 32x32 planes and rearranged so the channel axis
        comes last.
        """
        if self.type == 'mnist':
            return data.reshape(data.shape[0], 28, 28).astype("uint8")
        planes = data.reshape(data.shape[0], 3, 32, 32)
        return planes.transpose(0, 2, 3, 1).astype("uint8")

    def plot_imgs(self, in_data, n, random=False):
        """Show a grid of roughly ``n`` images taken from ``in_data``.

        Args:
            in_data: Iterable of flat image rows.
            n: Approximate number of images; the grid has at most five rows.
            random: When true, every cell shows a randomly chosen image
                instead of walking through ``in_data`` in order.
        """
        data = self.reshape_to_plot(np.array(list(in_data)))
        x1 = min(n // 2, 5)
        if x1 == 0:
            x1 = 1
        y1 = n // x1
        rows = min(x1, y1)
        cols = max(x1, y1)
        # squeeze=False always yields a 2-D array of axes; by default a
        # single-row grid is squeezed and ``ax[j][k]`` would fail.
        fig, ax = plt.subplots(rows, cols, figsize=(5, 5), squeeze=False)
        i = 0
        for j in range(rows):
            for k in range(cols):
                if random:
                    i = np.random.choice(len(data))
                ax[j][k].set_axis_off()
                ax[j][k].imshow(data[i:i + 1][0])
                i += 1
        plt.show()

    def plot_img(self, data):
        """Show a single flat image row.

        ``data`` must have shape ``(784,)`` for ``'mnist'`` and ``(3072,)``
        for every other dataset type.
        """
        if self.type != 'mnist':
            assert data.shape == (3072,)
            data = data.reshape(1, 3072)
            data = data.reshape(data.shape[0], 3, 32, 32)
            data = data.transpose(0, 2, 3, 1).astype("uint8")
        else:
            assert data.shape == (28 * 28,)
            data = data.reshape(1, 28, 28).astype('uint8')
        fig, ax = plt.subplots(figsize=(5, 5))
        ax.imshow(data[0])
        plt.show()
def unpickle(file):
    """Load and return the pickled object stored at path ``file``.

    The file is read in binary mode with ``encoding='bytes'``, so 8-bit
    strings pickled by Python 2 come back as ``bytes``.  That is why callers
    index the result with keys such as ``b'data'``.
    """
    import pickle

    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    with open(file, 'rb') as fo:
        contents = pickle.load(fo, encoding='bytes')
    return contents
# (answer below was posted 2020-02-21 00:48:52)
这是我最后采用的做法,它可以正常工作:
from keras.datasets import cifar10
import matplotlib.pyplot as plt
import cv2
import numpy as np

# Load the CIFAR-10 train/test splits through Keras.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Convert every training image to HSV and keep all of them, so that the
# histograms below describe the whole training set and not only the last
# image processed.  Iterating over x_train avoids hard-coding its length.
hsv_images = np.array(
    [cv2.cvtColor(image, cv2.COLOR_RGB2HSV) for image in x_train]
)
hue = hsv_images[..., 0]
sat = hsv_images[..., 1]
val = hsv_images[..., 2]

plt.figure(figsize=(10, 8))
plt.subplot(311)  # plot in the first cell
plt.subplots_adjust(hspace=.5)
plt.title("Hue")
plt.hist(hue.ravel(), bins=8)
plt.subplot(312)  # plot in the second cell
plt.title("Saturation")
plt.hist(sat.ravel(), bins=4)
plt.subplot(313)  # plot in the third cell
plt.title("Luminosity Value")
plt.hist(val.ravel(), bins=2)
plt.show()
# Source: https://stackoverflow.com/questions/60290978
复制相似问题