首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >解析cifar-10并将imgs改为HSV

解析cifar-10并将imgs改为HSV
EN

Stack Overflow用户
提问于 2020-02-19 07:38:06
回答 1查看 122关注 0票数 0

我正在使用下面的代码读取cifar-10数据集,并希望找到一种方法把其中的图像转换为HSV。我试着把data和labels这两个列表放到函数之外,但得到了这个错误:UnboundLocalError: local variable 'data' referenced before assignment

如何提取这些列表,以便先把图像转换为HSV,然后再绘制整个数据集的直方图?

代码语言:python
复制
import pickle
import numpy as np
from os.path import join
from os import listdir
import matplotlib.pyplot as plt
from tqdm import tqdm
import struct as st

class DataReader:
    """Read image datasets from pickled batch files and plot the images.

    ``type`` selects the on-disk layout that the loaders expect:

    * ``'cifar-100'`` -- ``root_dir`` contains the pickles ``train`` and
      ``test``.
    * ``'cifar-10'`` -- ``root_dir`` contains pickles whose file name starts
      with ``data_`` (training batches) or ``test_`` (test batches).
    * ``'mnist'`` -- delegated to ``self.read_mnist()``.
    """

    def __init__(self, root_dir, type='cifar-100'):
        """Remember the dataset directory and the dataset flavour."""
        self.root_dir = root_dir
        self.type = type

    def get_dict_from_pickle(self):
        """Load the ``train`` and ``test`` pickles into ``self.train_dict``
        and ``self.test_dict``."""
        self.train_dict = unpickle(join(self.root_dir, 'train'))
        self.test_dict = unpickle(join(self.root_dir, 'test'))

    def get_train_data(self):
        """Return the training split as ``(data, class_labels, sub_labels)``.

        For ``'cifar-100'`` the labels are the coarse and the fine labels.
        For ``'cifar-10'`` the third element is ``None``.  Any other ``type``
        than the three supported ones yields ``None``.
        """
        if self.type == 'cifar-100':
            self.get_dict_from_pickle()
            data = np.array(self.train_dict[b'data'])
            lbls_sub = np.array(self.train_dict[b'fine_labels'])
            lbls_class = np.array(self.train_dict[b'coarse_labels'])
            return data, lbls_class, lbls_sub
        elif self.type == 'cifar-10':
            # The accumulators must be created here, inside the function.
            # Without these two assignments `data.extend(...)` raises
            # UnboundLocalError: `data` is assigned in the cifar-100 branch
            # above, so Python treats it as a local name of this function.
            data = []
            labels = []
            print("Reading")
            # os.listdir() returns names in arbitrary order; sorting keeps
            # the row order reproducible between runs and machines.
            for file_ in tqdm(sorted(listdir(self.root_dir))):
                if file_.split('_')[0] == 'data':
                    batch = unpickle(join(self.root_dir, file_))
                    data.extend(batch[b'data'])
                    labels.extend(batch[b'labels'])

            return np.array(data), np.array(labels), None
        elif self.type == 'mnist':
            # NOTE(review): read_mnist() is not defined in the visible part
            # of this class -- confirm it exists before using type='mnist'.
            return self.read_mnist()

    def get_test_data(self):
        """Return the test split as ``(data, class_labels, sub_labels)``.

        Mirrors ``get_train_data``.  For ``'cifar-10'`` the data rows of all
        ``test_*`` files are stacked into one ``(N, 3072)`` array and the
        labels are returned as a flat array with one entry per row.  Types
        other than ``'cifar-100'`` and ``'cifar-10'`` yield ``None``.
        """
        if self.type == 'cifar-100':
            self.get_dict_from_pickle()
            data = np.array(self.test_dict[b'data'])
            lbls_sub = np.array(self.test_dict[b'fine_labels'])
            lbls_class = np.array(self.test_dict[b'coarse_labels'])
            return data, lbls_class, lbls_sub
        elif self.type == 'cifar-10':
            data = np.empty(shape=(0, 3072))
            labels = []
            for file_ in sorted(listdir(self.root_dir)):
                if file_.split('_')[0] == 'test':
                    batch = unpickle(join(self.root_dir, file_))
                    data = np.vstack((data, batch[b'data']))
                    print(data[data.shape[0] - 1])
                    # extend (not append): append would nest the whole label
                    # list and produce a (1, N) array that no longer lines up
                    # with the N rows of `data`.
                    labels.extend(batch[b'labels'])
            return np.array(data), np.array(labels), None

    def reshape_to_plot(self, data):
        """Turn flat rows into ``uint8`` images that ``imshow`` accepts.

        MNIST rows become ``(N, 28, 28)``.  All other rows are reshaped to
        ``(N, 3, 32, 32)`` and transposed to ``(N, 32, 32, 3)``.
        """
        if self.type == 'mnist':
            return data.reshape(data.shape[0], 28, 28).astype("uint8")
        planar = data.reshape(data.shape[0], 3, 32, 32)
        return planar.transpose(0, 2, 3, 1).astype("uint8")

    def plot_imgs(self, in_data, n, random=False):
        """Show roughly ``n`` images from ``in_data`` in a grid.

        The grid has at most five cells along its shorter side.  With
        ``random=True`` every cell shows a randomly chosen image, otherwise
        the images are shown in order starting from the first one.
        """
        data = self.reshape_to_plot(np.array(list(in_data)))
        x1 = min(n // 2, 5) or 1  # never divide by zero below
        y1 = n // x1
        x = min(x1, y1)
        y = max(x1, y1)
        # squeeze=False always returns a 2-D array of Axes.  By default
        # subplots() squeezes a 1 x k grid to 1-D and a 1 x 1 grid to a single
        # Axes, which made ax[j][k] fail for n = 1, 2 and 3.
        fig, ax = plt.subplots(x, y, figsize=(5, 5), squeeze=False)
        i = 0
        for j in range(x):
            for k in range(y):
                if random:
                    i = np.random.choice(len(data))
                ax[j][k].set_axis_off()
                ax[j][k].imshow(data[i])
                i += 1
        plt.show()

    def plot_img(self, data):
        """Show a single flat image row (3072 values, or 784 for MNIST)."""
        if self.type == 'mnist':
            assert data.shape == (28 * 28,)
            img = data.reshape(28, 28).astype('uint8')
        else:
            assert data.shape == (3072,)
            # Same planar -> interleaved conversion as reshape_to_plot().
            img = data.reshape(3, 32, 32).transpose(1, 2, 0).astype("uint8")
        fig, ax = plt.subplots(figsize=(5, 5))
        ax.imshow(img)
        plt.show()

def unpickle(file):
    """Load and return the object pickled in the file at path ``file``.

    ``encoding='bytes'`` makes 8-bit strings written by Python 2 come back
    as ``bytes``, which is why callers index the result with keys such as
    ``b'data'``.
    """
    # NOTE(security): pickle.load can execute arbitrary code while loading.
    # Only call this on dataset files obtained from a trusted source.
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding='bytes')
EN

回答 1

Stack Overflow用户

发布于 2020-02-21 00:48:52

这是我最后采用的做法,它可以正常工作。代码需要先导入:from keras.datasets import cifar10、import matplotlib.pyplot as plt 和 import cv2。

代码语言:python
复制
       # Plot hue / saturation / value histograms of the CIFAR-10 training set.
       from keras.datasets import cifar10
       import cv2
       import matplotlib.pyplot as plt
       import numpy as np

       (x_train, y_train), (x_test, y_test) = cifar10.load_data()

       # cv2.cvtColor converts one image at a time.  Keep every converted image
       # so the histograms below describe the whole training set; the previous
       # loop overwrote hue/sat/val on each pass and only the last image was
       # ever plotted.
       hsv_train = np.array(
           [cv2.cvtColor(img, cv2.COLOR_RGB2HSV) for img in x_train]
       )
       # The last axis holds the H, S and V channels in that order.
       hue = hsv_train[..., 0]
       sat = hsv_train[..., 1]
       val = hsv_train[..., 2]

       plt.figure(figsize=(10,8))
       plt.subplot(311)                             #plot in the first cell
       plt.subplots_adjust(hspace=.5)
       plt.title("Hue")
       plt.hist(hue.ravel(), bins=8)
       plt.subplot(312)                             #plot in the second cell
       plt.title("Saturation")
       plt.hist(sat.ravel(), bins=4)
       plt.subplot(313)                             #plot in the third cell
       plt.title("Luminosity Value")
       plt.hist(val.ravel(), bins=2)
       plt.show()
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/60290978

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档