Chainer: не может классифицировать, обученная модель (x) выдает ошибку

SerbentD спросил: 28 марта 2018 в 03:15 в: python

Я использую Chainer для распознавания цифр. Я подготовил модель в базе данных MNIST. Однако я не могу по какой-то причине классифицировать один пример. Я предполагал, что я не выбрал правильный формат для примера, но я пробовал его разными способами, и я все еще не могу решить эту проблему. Извиняюсь, если это очевидно, я не ознакомлен с Python.

Это выглядит так:

Код:

import chainer
import chainer.functions as F
import chainer.links as L
import os
import sys
from chainer.training import extensions
import numpy as np
from tkinter.filedialog import askopenfilename
from PIL import Image
from chainer import serializers
from chainer.dataset import concat_examplesfrom chainer.backends import cuda
from chainer import Function, gradient_check, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import Link, Chain, ChainListCONST_RESUME = ''class Network(chainer.Chain):
    def __init__(self, n_units, n_out):
        super(Network, self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(None, n_units)
            self.l2 = L.Linear(None, n_units)
            self.l3 = L.Linear(None, n_out)    def __call__(self, x):
        h1 = F.sigmoid(self.l1(x))
        h2 = F.sigmoid(self.l2(h1))
        return self.l3(h2)def query_yes_no(question, default="no"):
    """Ask a yes/no question via raw_input() and return their answer.    "question" is a string that is presented to the user.
    "default" is the presumed answer if the user just hits <Enter>.
        It must be "yes" (the default), "no" or None (meaning
        an answer is required of the user).    The "answer" return value is True for "yes" or False for "no".
    """
    valid = {"yes": True, "y": True, "ye": True,
             "no": False, "n": False}
    if default is None:
        prompt = " [y/n] "
    elif default == "yes":
        prompt = " [Y/n] "
    elif default == "no":
        prompt = " [y/N] "
    else:
        raise ValueError("invalid default answer: '%s'" % default)    while True:
        sys.stdout.write(question + prompt)
        choice = input().lower()
        if default is not None and choice == '':
            return valid[default]
        elif choice in valid:
            return valid[choice]
        else:
            sys.stdout.write("Please respond with 'yes' or 'no' "
                             "(or 'y' or 'n').\n")    return np.argmax(y.data)def main():
    #file_list = [None]*10
    #for i in range(10):
    #    file_list[i] = open('data{}.txt'.format(i), 'rb')    print('MNIST digit recognition.')
    usr_in : str    model = L.Classifier(Network(100, 10))
    chainer.backends.cuda.get_device_from_id(0).use()
    model.to_gpu()    usr_in = input('Input (t) to train, (l) to load.')
    while len(usr_in) != 1 and (usr_in[0] != 'l' or usr_in[0] != 't'):
        print('Invalid input.')
        usr_in = input()    if usr_in[0] == 't':        optimizer = chainer.optimizers.Adam()
        optimizer.setup(model)        train, test = chainer.datasets.get_mnist()        train_iter = chainer.iterators.SerialIterator(train, batch_size=100, shuffle=True)
        test_iter = chainer.iterators.SerialIterator(test, 100, repeat=False, shuffle=False)        updater = training.updaters.StandardUpdater(train_iter, optimizer, device=0)
        trainer = training.Trainer(updater, (10, 'epoch'), out='out.txt')        trainer.extend(extensions.Evaluator(test_iter, model, device=0))        trainer.extend(extensions.dump_graph('main/loss'))        frequency = 1
        trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))        trainer.extend(extensions.LogReport())        if extensions.PlotReport.available():
            trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss'], 'epoch', file_name='loss.png'))
            trainer.extend(extensions.PlotReport(['main/accuracy', 'validation/main/accuracy'], 'epoch', file_name='accuracy.png'))        trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
        trainer.extend(extensions.ProgressBar())        if CONST_RESUME:
            chainer.serializers.load_npz(CONST_RESUME, trainer)        trainer.run()        ans = query_yes_no('Would you like to save this network?')
        if ans:
            usr_in = input('Input filename: ')
            serializers.save_npz('{}.{}'.format(usr_in, 'npz'), model)    elif usr_in[0] == 'l':
        filename = askopenfilename(initialdir=os.getcwd(), title='Choose a file')
        serializers.load_npz(filename, model)    else:
        return    while True:
        ans = query_yes_no('Would you like to evaluate an image with the current network?')
        if ans:
            filename = askopenfilename(initialdir=os.getcwd(), title='Choose a file')
            file = Image.open(filename)
            bw_file = file.convert('L')
            size = 28, 28
            bw_file.thumbnail(size, Image.ANTIALIAS)
            pix = bw_file.load()
            x = np.empty([28 * 28])
            for i in range(28):
                for j in range(28):
                    x[i * 28 + j] = pix[i, j]            #gpu_id = 0
            #batch = (, gpu_id)
            x = (x.astype(np.float32))[None, ...]            y = model(x)
            print('predicted_label:', y.argmax(axis=1)[0])        else:
            returnmain()

1 ответ

Есть решение
corochann ответил: 28 марта 2018 в 12:01

Можете ли вы показать нам сообщение об ошибке? В противном случае трудно догадаться, где причина ошибки.

Но я предполагаю, что форма ввода может отличаться. Когда вы используете встроенную функцию цепочки для получения набора данных, train, test = chainer.datasets.get_mnist() Это изображение набора данных форма (minibatch, channel, height, width). но кажется, что вы создаете входные данные x как форму (minibatch, height * width =28*28=784), которая является другой формой ??

Вы также можете обратиться к некоторому учебнику по цепочке,

  • работа с цепочками
  • учебник по углубленному обучению с помощью цепочки
    • код вывода MNIST