/*
 * ia2005/backprop/.#reseau.cpp.1.4
 * Last modified: 2024-11-12 17:41:10 +01:00 (95 lines, 2.1 KiB)
 * Implementation of the reseau class.
 */

#include "reseau.h"
/*
 * Builds a three-layer network (input -> hidden -> output) and attaches the
 * shared-memory segment (see initshm(), which terminates the process on
 * failure).
 *
 * In  : number of neurons in the input layer.
 * Hid : number of neurons in the hidden layer.
 * Out : number of neurons in the output layer.
 *
 * Each layer receives the address of a sibling layer member. Member addresses
 * are fixed for the lifetime of the object, so taking them before the
 * pointed-to member is assigned is fine.
 * NOTE(review): layers are built output-first, presumably because a layer's
 * constructor reads the size of the layer it points to -- keep this order.
 */
reseau::reseau(int In, int Hid, int Out)
{
    /* The output layer is sized with Out and the input layer with In
     * (these two sizes used to be swapped). */
    Ocouche = OutputCouche(Out, &Hcouche);
    Hcouche = HiddenCouche(Hid, &Ocouche);
    Icouche = InputCouche(In, &Hcouche);
    initshm();
}
/*
 * Detaches this process from the shared-memory segment attached in initshm().
 *
 * The segment itself is not removed (no IPC_RMID). This class only attaches
 * to an existing segment: initshm() calls shmget without IPC_CREAT.
 * NOTE(review): copying a reseau would make two objects share SData, and the
 * first destructor would detach it under the other -- avoid copies.
 */
reseau::~reseau()
{
    if (SData != NULL && SData != (struct shmdata *) -1)
    {
        shmdt(SData);
    }
}
/*
 * Runs one forward pass through the input, hidden and output layers.
 *
 * input  : read only; must hold at least Icouche.getNumber() values.
 * target : output buffer; receives util.accept() of each output neuron's
 *          value, so it must hold at least Ocouche.getNumber() values.
 *          Its previous contents are overwritten.
 *
 * Results are written to target[] only; the shared-memory segment (SData)
 * is not touched here.
 */
void reseau::forward(bool input[], bool target[])
{
    /* Copy the raw input into the vector form the input layer expects. */
    std::vector<bool> inputValues(input, input + Icouche.getNumber());

    /* Propagate layer by layer. */
    Icouche.activate(inputValues);
    Hcouche.activate(Icouche);
    Ocouche.activate();

    /* Threshold each output neuron's value into the caller's buffer. */
    const unsigned outputCount = Ocouche.getNumber();
    for (unsigned k = 0; k < outputCount; ++k)
    {
        target[k] = util.accept(Ocouche[k].getWeight());
    }
}
/*
 * Computes per-neuron error terms ("deltas") for the output and hidden layers
 * and passes them to the layers' backPropagate() methods: hidden deltas go to
 * the input layer, output deltas go to the output layer.
 *
 * input  : unused.
 * target : expected outputs. Read for indices [0, Ocouche.getNumber()) and
 *          also for [0, Hcouche.getNumber()) -- see the note below.
 *
 * NOTE(review): suspected defects left unchanged (a fix needs the layer
 * weight API, which is not visible from this file):
 *  - the hidden-layer loop indexes target[] with hidden-neuron indices, so it
 *    reads past the end of target when the hidden layer is larger than the
 *    output layer, and hidden deltas do not depend on the output deltas as
 *    standard backpropagation requires;
 *  - util.dsigmoid is applied to the error (target - value) instead of the
 *    error being multiplied by the derivative at the neuron's value; whether
 *    that is intended depends on dsigmoid's definition -- confirm.
 */
void reseau::backward(bool input[], bool target[])
{
    const unsigned outCount = Ocouche.getNumber();
    const unsigned hidCount = Hcouche.getNumber();

    /* Output layer deltas. */
    std::vector<double> outDelta(outCount);
    for (unsigned n = 0; n < outCount; ++n)
    {
        const double diff = static_cast<double>(target[n]) - Ocouche[n].getWeight();
        outDelta[n] = util.dsigmoid(diff);
    }

    /* Hidden layer deltas. */
    std::vector<double> hidDelta(hidCount);
    for (unsigned n = 0; n < hidCount; ++n)
    {
        const double diff = static_cast<double>(target[n]) - Hcouche[n].getWeight();
        hidDelta[n] = util.dsigmoid(diff);
    }

    Icouche.backPropagate(hidDelta);
    Ocouche.backPropagate(outDelta);
}
void reseau::initshm()
{
if ((shmid = shmget(SHMKEY, sizeof(struct shmdata), 0666)) < 0)
{
perror("Unable to get shm id \n");
exit(1);
}
if ((SData = (struct shmdata *)shmat(shmid, NULL, 0)) == (struct shmdata *) -1)
{
perror("Unable to attach shm segment\n");
exit(1);
}
}
/*
 * Returns the Euclidean distance between target[] and the output layer's
 * current values (Ocouche[i].getWeight()):
 *     sqrt( sum_i (target[i] - out_i)^2 )
 * target must hold at least Ocouche.getNumber() values.
 */
double reseau::getError(bool target[])
{
    double sumSquares = 0.0;
    const unsigned count = Ocouche.getNumber();
    for (unsigned k = 0; k != count; ++k)
    {
        const double diff = static_cast<double>(target[k]) - Ocouche[k].getWeight();
        sumSquares += pow(diff, 2);
    }
    return sqrt(sumSquares);
}
double reseau::learnOne(bool input[], bool target[])
{
double error=0;
while (error>LEARNACCEPT)
{
forward(input,target);
backward(input,target);
error=getError(target);
cout << "erreur : " << error << endl ;
}
return error;
}
/*
 * Trains on each (inputs[i], targets[i]) pair in turn with learnOne() and
 * returns the mean of the per-example errors it reports.
 *
 * Only the first min(inputs.size(), targets.size()) pairs are used, so a
 * shorter targets vector cannot be indexed out of range.
 * Returns 0.0 when there are no pairs, instead of dividing by zero (NaN).
 */
double reseau::learnAll(std::vector<bool *> inputs, std::vector<bool *> targets)
{
    const std::vector<bool *>::size_type count =
        inputs.size() < targets.size() ? inputs.size() : targets.size();
    if (count == 0)
    {
        return 0.0;
    }

    double totalError = 0.0;
    for (std::vector<bool *>::size_type i = 0; i < count; ++i)
    {
        totalError += learnOne(inputs[i], targets[i]);
    }
    return totalError / static_cast<double>(count);
}
/* END */