mirror of
https://github.com/curioustorvald/Terrarum-sans-bitmap.git
synced 2026-06-10 16:04:04 +09:00
revised autokem model
This commit is contained in:
@@ -128,12 +128,8 @@ static void save_weights(Network *net, Network *best) {
|
||||
copy_tensor_data(best->conv2.bias, net->conv2.bias);
|
||||
copy_tensor_data(best->fc1.weight, net->fc1.weight);
|
||||
copy_tensor_data(best->fc1.bias, net->fc1.bias);
|
||||
copy_tensor_data(best->head_shape.weight, net->head_shape.weight);
|
||||
copy_tensor_data(best->head_shape.bias, net->head_shape.bias);
|
||||
copy_tensor_data(best->head_ytype.weight, net->head_ytype.weight);
|
||||
copy_tensor_data(best->head_ytype.bias, net->head_ytype.bias);
|
||||
copy_tensor_data(best->head_lowheight.weight, net->head_lowheight.weight);
|
||||
copy_tensor_data(best->head_lowheight.bias, net->head_lowheight.bias);
|
||||
copy_tensor_data(best->output.weight, net->output.weight);
|
||||
copy_tensor_data(best->output.bias, net->output.bias);
|
||||
}
|
||||
|
||||
/* ---- Training ---- */
|
||||
@@ -247,19 +243,15 @@ int train_model(void) {
|
||||
int ishape[] = {bs, 1, 20, 15};
|
||||
Tensor *input = tensor_alloc(4, ishape);
|
||||
|
||||
int sshape[] = {bs, 10};
|
||||
Tensor *tgt_shape = tensor_alloc(2, sshape);
|
||||
|
||||
int yshape[] = {bs, 1};
|
||||
Tensor *tgt_ytype = tensor_alloc(2, yshape);
|
||||
Tensor *tgt_lh = tensor_alloc(2, yshape);
|
||||
int tshape[] = {bs, 12};
|
||||
Tensor *target = tensor_alloc(2, tshape);
|
||||
|
||||
for (int i = 0; i < bs; i++) {
|
||||
Sample *s = &all_samples[indices[start + i]];
|
||||
memcpy(input->data + i * 300, s->input, 300 * sizeof(float));
|
||||
memcpy(tgt_shape->data + i * 10, s->shape, 10 * sizeof(float));
|
||||
tgt_ytype->data[i] = s->ytype;
|
||||
tgt_lh->data[i] = s->lowheight;
|
||||
memcpy(target->data + i * 12, s->shape, 10 * sizeof(float));
|
||||
target->data[i * 12 + 10] = s->ytype;
|
||||
target->data[i * 12 + 11] = s->lowheight;
|
||||
}
|
||||
|
||||
/* Forward */
|
||||
@@ -267,21 +259,19 @@ int train_model(void) {
|
||||
network_forward(net, input, 1);
|
||||
|
||||
/* Loss */
|
||||
float loss = network_bce_loss(net, tgt_shape, tgt_ytype, tgt_lh);
|
||||
float loss = network_bce_loss(net, target);
|
||||
train_loss += loss;
|
||||
n_batches++;
|
||||
|
||||
/* Backward */
|
||||
network_backward(net, tgt_shape, tgt_ytype, tgt_lh);
|
||||
network_backward(net, target);
|
||||
|
||||
/* Adam step */
|
||||
adam_t++;
|
||||
network_adam_step(net, lr, beta1, beta2, eps, adam_t);
|
||||
|
||||
tensor_free(input);
|
||||
tensor_free(tgt_shape);
|
||||
tensor_free(tgt_ytype);
|
||||
tensor_free(tgt_lh);
|
||||
tensor_free(target);
|
||||
}
|
||||
|
||||
train_loss /= (float)n_batches;
|
||||
@@ -295,29 +285,23 @@ int train_model(void) {
|
||||
int ishape[] = {bs, 1, 20, 15};
|
||||
Tensor *input = tensor_alloc(4, ishape);
|
||||
|
||||
int sshape[] = {bs, 10};
|
||||
Tensor *tgt_shape = tensor_alloc(2, sshape);
|
||||
|
||||
int yshape[] = {bs, 1};
|
||||
Tensor *tgt_ytype = tensor_alloc(2, yshape);
|
||||
Tensor *tgt_lh = tensor_alloc(2, yshape);
|
||||
int tshape[] = {bs, 12};
|
||||
Tensor *target = tensor_alloc(2, tshape);
|
||||
|
||||
for (int i = 0; i < bs; i++) {
|
||||
Sample *s = &all_samples[indices[n_train + start + i]];
|
||||
memcpy(input->data + i * 300, s->input, 300 * sizeof(float));
|
||||
memcpy(tgt_shape->data + i * 10, s->shape, 10 * sizeof(float));
|
||||
tgt_ytype->data[i] = s->ytype;
|
||||
tgt_lh->data[i] = s->lowheight;
|
||||
memcpy(target->data + i * 12, s->shape, 10 * sizeof(float));
|
||||
target->data[i * 12 + 10] = s->ytype;
|
||||
target->data[i * 12 + 11] = s->lowheight;
|
||||
}
|
||||
|
||||
network_forward(net, input, 0);
|
||||
val_loss += network_bce_loss(net, tgt_shape, tgt_ytype, tgt_lh);
|
||||
val_loss += network_bce_loss(net, target);
|
||||
val_batches++;
|
||||
|
||||
tensor_free(input);
|
||||
tensor_free(tgt_shape);
|
||||
tensor_free(tgt_ytype);
|
||||
tensor_free(tgt_lh);
|
||||
tensor_free(target);
|
||||
}
|
||||
|
||||
val_loss /= (float)val_batches;
|
||||
|
||||
Reference in New Issue
Block a user