diff --git a/Autokem/CLAUDE.md b/Autokem/CLAUDE.md index 16b1e10..8884a33 100644 --- a/Autokem/CLAUDE.md +++ b/Autokem/CLAUDE.md @@ -25,20 +25,32 @@ make clean - `apply` creates `.bak` backup, runs inference per cell, writes Y+5 (lowheight) and Y+6 (kern data) pixels. Skips cells with width=0, writeOnTop, or compiler directives - Model file `autokem.safetensors` must be in the working directory +### PyTorch training (faster prototyping) + +```bash +cd Autokem +.venv/bin/python train_torch.py # train with defaults +.venv/bin/python train_torch.py --epochs 300 # override max epochs +.venv/bin/python train_torch.py --lr 0.0005 # override learning rate +.venv/bin/python train_torch.py --load model.safetensors # resume from weights +``` + +- Drop-in replacement for `./autokem train` — reads the same sheets, produces the same safetensors format +- The exported `autokem.safetensors` is directly loadable by the C inference code (`./autokem apply`) +- Requires: `pip install torch numpy` (venv at `.venv/`) + ## Architecture ### Neural network ``` Input: 15x20x1 binary (300 values, alpha >= 0x80 → 1.0) - Conv2D(1→12, 3x3, same) → LeakyReLU(0.01) - Conv2D(12→16, 3x3, same) → LeakyReLU(0.01) - Flatten → 4800 - Dense(4800→24) → LeakyReLU(0.01) - ├── Dense(24→10) → sigmoid (shape bits A-H, J, K) - ├── Dense(24→1) → sigmoid (Y-type) - └── Dense(24→1) → sigmoid (lowheight) -Total: ~117,388 params (~460 KB float32) + Conv2D(1→32, 7x7, pad=1) → SiLU + Conv2D(32→64, 7x7, pad=1) → SiLU + Global Average Pool → [batch, 64] + Dense(64→256) → SiLU + Dense(256→12) → sigmoid (10 shape bits + 1 ytype + 1 lowheight) +Total: ~121,740 params (~476 KB float32) ``` Training: Adam (lr=0.001, beta1=0.9, beta2=0.999), BCE loss, batch size 32, early stopping patience 10. 
@@ -49,8 +61,9 @@ Training: Adam (lr=0.001, beta1=0.9, beta2=0.999), BCE loss, batch size 32, earl |------|---------| | `main.c` | CLI dispatch | | `tga.h/tga.c` | TGA reader/writer — BGRA↔RGBA8888, row-order handling, per-pixel write-in-place | -| `nn.h/nn.c` | Tensor, Conv2D (same padding), Dense, LeakyReLU, sigmoid, Adam, He init | -| `safetensor.h/safetensor.c` | `.safetensors` serialisation — 12 named tensors + JSON metadata | +| `nn.h/nn.c` | Tensor, Conv2D (configurable padding), Dense, SiLU, sigmoid, global avg pool, Adam, He init | +| `safetensor.h/safetensor.c` | `.safetensors` serialisation — 8 named tensors + JSON metadata | +| `train_torch.py` | PyTorch training script — same data pipeline and architecture, exports C-compatible safetensors | | `train.h/train.c` | Data collection from sheets, training loop, validation, label distribution | | `apply.h/apply.c` | Backup, eligibility checks, inference, pixel composition | diff --git a/Autokem/autokem.safetensors b/Autokem/autokem.safetensors index 26d3365..22feab6 100644 --- a/Autokem/autokem.safetensors +++ b/Autokem/autokem.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:1642bc950d7a027953e5efa3c2e8806ac2b62250078997c4bccd2e4b192f915f -size 470552 +oid sha256:7fea59332d8e12ad664b8bcf7dcd7f538237da7d37a9c07fdadfdfe245736b49 +size 487640 diff --git a/Autokem/eval.sh b/Autokem/eval.sh new file mode 100755 index 0000000..29d398e --- /dev/null +++ b/Autokem/eval.sh @@ -0,0 +1,68 @@ +#!/usr/bin/env bash +# Run train_torch.py N times and report mean ± stddev of per-bit and overall accuracy. +# Usage: ./eval.sh [runs] [extra train_torch.py args...] +# e.g. 
./eval.sh 10 +# ./eval.sh 5 --epochs 300 --lr 0.0005 + +set -euo pipefail +cd "$(dirname "$0")" + +RUNS="${1:-10}" +shift 2>/dev/null || true +EXTRA_ARGS="$*" +PYTHON="${PYTHON:-.venv/bin/python3}" +RESULTS_FILE=$(mktemp) + +trap 'rm -f "$RESULTS_FILE"' EXIT + +echo "=== Autokem evaluation: $RUNS runs ===" +[ -n "$EXTRA_ARGS" ] && echo "Extra args: $EXTRA_ARGS" +echo + +for i in $(seq 1 "$RUNS"); do + echo "--- Run $i/$RUNS ---" + OUT=$("$PYTHON" train_torch.py --save /dev/null $EXTRA_ARGS 2>&1) + + # Extract per-bit line (the one after "Per-bit accuracy"): A:53.9% B:46.7% ... + PERBIT=$(echo "$OUT" | grep -A1 'Per-bit accuracy' | tail -1) + # Extract overall line: Overall: 5267/6660 (79.08%) + OVERALL=$(echo "$OUT" | grep -oP 'Overall:.*\(\K[0-9.]+') + # Extract val_loss + VALLOSS=$(echo "$OUT" | grep -oP 'val_loss: \K[0-9.]+' | tail -1) + + # Parse per-bit percentages into a tab-separated line + BITS=$(echo "$PERBIT" | grep -oP '[0-9.]+(?=%)' | tr '\n' '\t') + + echo "$BITS$OVERALL $VALLOSS" >> "$RESULTS_FILE" + echo " val_loss=$VALLOSS overall=$OVERALL%" +done + +echo +echo "=== Results ($RUNS runs) ===" + +"$PYTHON" - "$RESULTS_FILE" <<'PYEOF' +import sys +import numpy as np + +names = ['A','B','C','D','E','F','G','H','J','K','Ytype','LowH','Overall','ValLoss'] +data = [] +with open(sys.argv[1]) as f: + for line in f: + vals = line.strip().split('\t') + if len(vals) >= len(names): + data.append([float(v) for v in vals[:len(names)]]) + +if not data: + print("No data collected!") + sys.exit(1) + +arr = np.array(data) +means = arr.mean(axis=0) +stds = arr.std(axis=0) + +print(f"{'Metric':<10s} {'Mean':>8s} {'StdDev':>8s}") +print("-" * 28) +for i, name in enumerate(names): + unit = '' if name == 'ValLoss' else '%' + print(f"{name:<10s} {means[i]:>7.2f}{unit} {stds[i]:>7.2f}{unit}") +PYEOF diff --git a/Autokem/nn.c b/Autokem/nn.c index a7b9cbf..6383eac 100644 --- a/Autokem/nn.c +++ b/Autokem/nn.c @@ -71,14 +71,6 @@ static void he_init(Tensor *w, int fan_in) { /* 
---- Activations ---- */ -static inline float leaky_relu(float x) { - return x >= 0.0f ? x : 0.01f * x; -} - -static inline float leaky_relu_grad(float x) { - return x >= 0.0f ? 1.0f : 0.01f; -} - static inline float sigmoid_f(float x) { if (x >= 0.0f) { float ez = expf(-x); @@ -89,13 +81,24 @@ static inline float sigmoid_f(float x) { } } +static inline float silu_f(float x) { + return x * sigmoid_f(x); +} + +static inline float silu_grad(float x) { + float s = sigmoid_f(x); + return s * (1.0f + x * (1.0f - s)); +} + /* ---- Conv2D forward/backward ---- */ -static void conv2d_init(Conv2D *c, int in_ch, int out_ch, int kh, int kw) { +static void conv2d_init(Conv2D *c, int in_ch, int out_ch, int kh, int kw, int pad) { c->in_ch = in_ch; c->out_ch = out_ch; c->kh = kh; c->kw = kw; + c->pad_h = pad; + c->pad_w = pad; int wshape[] = {out_ch, in_ch, kh, kw}; int bshape[] = {out_ch}; @@ -125,13 +128,13 @@ static void conv2d_free(Conv2D *c) { tensor_free(c->input_cache); } -/* Forward: input [batch, in_ch, H, W] -> output [batch, out_ch, H, W] (same padding) */ +/* Forward: input [batch, in_ch, H, W] -> output [batch, out_ch, oH, oW] */ static Tensor *conv2d_forward(Conv2D *c, Tensor *input, int training) { int batch = input->shape[0]; int in_ch = c->in_ch, out_ch = c->out_ch; int H = input->shape[2], W = input->shape[3]; int kh = c->kh, kw = c->kw; - int ph = kh / 2, pw = kw / 2; + int ph = c->pad_h, pw = c->pad_w; if (training) { tensor_free(c->input_cache); @@ -139,13 +142,15 @@ static Tensor *conv2d_forward(Conv2D *c, Tensor *input, int training) { memcpy(c->input_cache->data, input->data, (size_t)input->size * sizeof(float)); } - int oshape[] = {batch, out_ch, H, W}; + int oH = H + 2 * ph - kh + 1; + int oW = W + 2 * pw - kw + 1; + int oshape[] = {batch, out_ch, oH, oW}; Tensor *out = tensor_alloc(4, oshape); for (int b = 0; b < batch; b++) { for (int oc = 0; oc < out_ch; oc++) { - for (int oh = 0; oh < H; oh++) { - for (int ow = 0; ow < W; ow++) { + for (int oh = 0; 
oh < oH; oh++) { + for (int ow = 0; ow < oW; ow++) { float sum = c->bias->data[oc]; for (int ic = 0; ic < in_ch; ic++) { for (int fh = 0; fh < kh; fh++) { @@ -160,7 +165,7 @@ static Tensor *conv2d_forward(Conv2D *c, Tensor *input, int training) { } } } - out->data[((b * out_ch + oc) * H + oh) * W + ow] = sum; + out->data[((b * out_ch + oc) * oH + oh) * oW + ow] = sum; } } } @@ -168,22 +173,23 @@ static Tensor *conv2d_forward(Conv2D *c, Tensor *input, int training) { return out; } -/* Backward: grad_output [batch, out_ch, H, W] -> grad_input [batch, in_ch, H, W] */ +/* Backward: grad_output [batch, out_ch, oH, oW] -> grad_input [batch, in_ch, H, W] */ static Tensor *conv2d_backward(Conv2D *c, Tensor *grad_output) { Tensor *input = c->input_cache; int batch = input->shape[0]; int in_ch = c->in_ch, out_ch = c->out_ch; int H = input->shape[2], W = input->shape[3]; int kh = c->kh, kw = c->kw; - int ph = kh / 2, pw = kw / 2; + int ph = c->pad_h, pw = c->pad_w; + int oH = grad_output->shape[2], oW = grad_output->shape[3]; Tensor *grad_input = tensor_zeros(input->ndim, input->shape); for (int b = 0; b < batch; b++) { for (int oc = 0; oc < out_ch; oc++) { - for (int oh = 0; oh < H; oh++) { - for (int ow = 0; ow < W; ow++) { - float go = grad_output->data[((b * out_ch + oc) * H + oh) * W + ow]; + for (int oh = 0; oh < oH; oh++) { + for (int ow = 0; ow < oW; ow++) { + float go = grad_output->data[((b * out_ch + oc) * oH + oh) * oW + ow]; c->grad_bias->data[oc] += go; for (int ic = 0; ic < in_ch; ic++) { for (int fh = 0; fh < kh; fh++) { @@ -288,22 +294,68 @@ static Tensor *dense_backward(Dense *d, Tensor *grad_output) { return grad_input; } -/* ---- LeakyReLU helpers on tensors ---- */ +/* ---- SiLU helpers on tensors ---- */ -static Tensor *apply_leaky_relu(Tensor *input) { +static Tensor *apply_silu(Tensor *input) { Tensor *out = tensor_alloc(input->ndim, input->shape); for (int i = 0; i < input->size; i++) - out->data[i] = leaky_relu(input->data[i]); + out->data[i] = 
silu_f(input->data[i]); return out; } -static Tensor *apply_leaky_relu_backward(Tensor *grad_output, Tensor *pre_activation) { +static Tensor *apply_silu_backward(Tensor *grad_output, Tensor *pre_activation) { Tensor *grad = tensor_alloc(grad_output->ndim, grad_output->shape); for (int i = 0; i < grad_output->size; i++) - grad->data[i] = grad_output->data[i] * leaky_relu_grad(pre_activation->data[i]); + grad->data[i] = grad_output->data[i] * silu_grad(pre_activation->data[i]); return grad; } +/* ---- Global Average Pooling ---- */ + +/* Forward: input [batch, C, H, W] -> output [batch, C] */ +static Tensor *global_avg_pool_forward(Tensor *input) { + int batch = input->shape[0]; + int C = input->shape[1]; + int H = input->shape[2]; + int W = input->shape[3]; + int hw = H * W; + + int oshape[] = {batch, C}; + Tensor *out = tensor_alloc(2, oshape); + + for (int b = 0; b < batch; b++) { + for (int c = 0; c < C; c++) { + float sum = 0.0f; + int base = (b * C + c) * hw; + for (int i = 0; i < hw; i++) + sum += input->data[base + i]; + out->data[b * C + c] = sum / (float)hw; + } + } + return out; +} + +/* Backward: grad_output [batch, C] -> grad_input [batch, C, H, W] */ +static Tensor *global_avg_pool_backward(Tensor *grad_output, int H, int W) { + int batch = grad_output->shape[0]; + int C = grad_output->shape[1]; + int hw = H * W; + float scale = 1.0f / (float)hw; + + int ishape[] = {batch, C, H, W}; + Tensor *grad_input = tensor_alloc(4, ishape); + + for (int b = 0; b < batch; b++) { + for (int c = 0; c < C; c++) { + float go = grad_output->data[b * C + c] * scale; + int base = (b * C + c) * hw; + for (int i = 0; i < hw; i++) + grad_input->data[base + i] = go; + } + } + return grad_input; +} + /* ---- Sigmoid on tensor ---- */ static Tensor *apply_sigmoid(Tensor *input) { @@ -335,12 +387,10 @@ Network *network_create(void) { rng_seed((uint64_t)time(NULL) ^ 0xDEADBEEF); Network *net = calloc(1, sizeof(Network)); - conv2d_init(&net->conv1, 1, 12, 3, 3); - 
conv2d_init(&net->conv2, 12, 16, 3, 3); - dense_init(&net->fc1, 4800, 24); - dense_init(&net->head_shape, 24, 10); - dense_init(&net->head_ytype, 24, 1); - dense_init(&net->head_lowheight, 24, 1); + conv2d_init(&net->conv1, 1, 32, 7, 7, 1); + conv2d_init(&net->conv2, 32, 64, 7, 7, 1); + dense_init(&net->fc1, 64, 256); + dense_init(&net->output, 256, 12); return net; } @@ -349,133 +399,92 @@ void network_free(Network *net) { conv2d_free(&net->conv1); conv2d_free(&net->conv2); dense_free(&net->fc1); - dense_free(&net->head_shape); - dense_free(&net->head_ytype); - dense_free(&net->head_lowheight); + dense_free(&net->output); tensor_free(net->act_conv1); - tensor_free(net->act_relu1); + tensor_free(net->act_silu1); tensor_free(net->act_conv2); - tensor_free(net->act_relu2); - tensor_free(net->act_flat); + tensor_free(net->act_silu2); + tensor_free(net->act_pool); tensor_free(net->act_fc1); - tensor_free(net->act_relu3); - tensor_free(net->out_shape); - tensor_free(net->out_ytype); - tensor_free(net->out_lowheight); + tensor_free(net->act_silu3); + tensor_free(net->act_logits); + tensor_free(net->out_all); free(net); } static void free_activations(Network *net) { tensor_free(net->act_conv1); net->act_conv1 = NULL; - tensor_free(net->act_relu1); net->act_relu1 = NULL; + tensor_free(net->act_silu1); net->act_silu1 = NULL; tensor_free(net->act_conv2); net->act_conv2 = NULL; - tensor_free(net->act_relu2); net->act_relu2 = NULL; - tensor_free(net->act_flat); net->act_flat = NULL; + tensor_free(net->act_silu2); net->act_silu2 = NULL; + tensor_free(net->act_pool); net->act_pool = NULL; tensor_free(net->act_fc1); net->act_fc1 = NULL; - tensor_free(net->act_relu3); net->act_relu3 = NULL; - tensor_free(net->out_shape); net->out_shape = NULL; - tensor_free(net->out_ytype); net->out_ytype = NULL; - tensor_free(net->out_lowheight); net->out_lowheight = NULL; + tensor_free(net->act_silu3); net->act_silu3 = NULL; + tensor_free(net->act_logits); net->act_logits = NULL; + 
tensor_free(net->out_all); net->out_all = NULL; } void network_forward(Network *net, Tensor *input, int training) { free_activations(net); - /* Conv1 -> LeakyReLU */ + /* Conv1 -> SiLU */ net->act_conv1 = conv2d_forward(&net->conv1, input, training); - net->act_relu1 = apply_leaky_relu(net->act_conv1); + net->act_silu1 = apply_silu(net->act_conv1); - /* Conv2 -> LeakyReLU */ - net->act_conv2 = conv2d_forward(&net->conv2, net->act_relu1, training); - net->act_relu2 = apply_leaky_relu(net->act_conv2); + /* Conv2 -> SiLU */ + net->act_conv2 = conv2d_forward(&net->conv2, net->act_silu1, training); + net->act_silu2 = apply_silu(net->act_conv2); - /* Flatten: [batch, 16, 20, 15] -> [batch, 4800] */ - int batch = net->act_relu2->shape[0]; - int flat_size = net->act_relu2->size / batch; - int fshape[] = {batch, flat_size}; - net->act_flat = tensor_alloc(2, fshape); - memcpy(net->act_flat->data, net->act_relu2->data, (size_t)net->act_relu2->size * sizeof(float)); + /* Global Average Pool */ + net->act_pool = global_avg_pool_forward(net->act_silu2); - /* FC1 -> LeakyReLU */ - net->act_fc1 = dense_forward(&net->fc1, net->act_flat, training); - net->act_relu3 = apply_leaky_relu(net->act_fc1); + /* FC1 -> SiLU */ + net->act_fc1 = dense_forward(&net->fc1, net->act_pool, training); + net->act_silu3 = apply_silu(net->act_fc1); - /* Three heads with sigmoid */ - Tensor *logit_shape = dense_forward(&net->head_shape, net->act_relu3, training); - Tensor *logit_ytype = dense_forward(&net->head_ytype, net->act_relu3, training); - Tensor *logit_lowheight = dense_forward(&net->head_lowheight, net->act_relu3, training); - - net->out_shape = apply_sigmoid(logit_shape); - net->out_ytype = apply_sigmoid(logit_ytype); - net->out_lowheight = apply_sigmoid(logit_lowheight); - - tensor_free(logit_shape); - tensor_free(logit_ytype); - tensor_free(logit_lowheight); + /* Output -> Sigmoid */ + net->act_logits = dense_forward(&net->output, net->act_silu3, training); + net->out_all = 
apply_sigmoid(net->act_logits); } -void network_backward(Network *net, Tensor *target_shape, Tensor *target_ytype, Tensor *target_lowheight) { - int batch = net->out_shape->shape[0]; +void network_backward(Network *net, Tensor *target) { + int batch = net->out_all->shape[0]; + int n_out = 12; - /* BCE gradient at sigmoid: d_logit = pred - target */ - /* Head: shape (10 outputs) */ - int gs[] = {batch, 10}; - Tensor *grad_logit_shape = tensor_alloc(2, gs); - for (int i = 0; i < batch * 10; i++) - grad_logit_shape->data[i] = (net->out_shape->data[i] - target_shape->data[i]) / (float)batch; + /* BCE gradient at sigmoid: d_logit = (pred - target) / batch */ + int gs[] = {batch, n_out}; + Tensor *grad_logits = tensor_alloc(2, gs); + for (int i = 0; i < batch * n_out; i++) + grad_logits->data[i] = (net->out_all->data[i] - target->data[i]) / (float)batch; - int gy[] = {batch, 1}; - Tensor *grad_logit_ytype = tensor_alloc(2, gy); - for (int i = 0; i < batch; i++) - grad_logit_ytype->data[i] = (net->out_ytype->data[i] - target_ytype->data[i]) / (float)batch; + /* Output layer backward */ + Tensor *grad_silu3 = dense_backward(&net->output, grad_logits); + tensor_free(grad_logits); - Tensor *grad_logit_lh = tensor_alloc(2, gy); - for (int i = 0; i < batch; i++) - grad_logit_lh->data[i] = (net->out_lowheight->data[i] - target_lowheight->data[i]) / (float)batch; + /* SiLU backward (fc1) */ + Tensor *grad_fc1_out = apply_silu_backward(grad_silu3, net->act_fc1); + tensor_free(grad_silu3); - /* Backward through heads */ - Tensor *grad_relu3_s = dense_backward(&net->head_shape, grad_logit_shape); - Tensor *grad_relu3_y = dense_backward(&net->head_ytype, grad_logit_ytype); - Tensor *grad_relu3_l = dense_backward(&net->head_lowheight, grad_logit_lh); - - /* Sum gradients from three heads */ - int r3shape[] = {batch, 24}; - Tensor *grad_relu3 = tensor_zeros(2, r3shape); - for (int i = 0; i < batch * 24; i++) - grad_relu3->data[i] = grad_relu3_s->data[i] + grad_relu3_y->data[i] + 
grad_relu3_l->data[i]; - - tensor_free(grad_logit_shape); - tensor_free(grad_logit_ytype); - tensor_free(grad_logit_lh); - tensor_free(grad_relu3_s); - tensor_free(grad_relu3_y); - tensor_free(grad_relu3_l); - - /* LeakyReLU backward (fc1 output) */ - Tensor *grad_fc1_out = apply_leaky_relu_backward(grad_relu3, net->act_fc1); - tensor_free(grad_relu3); - - /* Dense fc1 backward */ - Tensor *grad_flat = dense_backward(&net->fc1, grad_fc1_out); + /* FC1 backward */ + Tensor *grad_pool = dense_backward(&net->fc1, grad_fc1_out); tensor_free(grad_fc1_out); - /* Unflatten: [batch, 4800] -> [batch, 16, 20, 15] */ - int ushape[] = {batch, 16, 20, 15}; - Tensor *grad_relu2 = tensor_alloc(4, ushape); - memcpy(grad_relu2->data, grad_flat->data, (size_t)grad_flat->size * sizeof(float)); - tensor_free(grad_flat); + /* Global Average Pool backward */ + int H = net->act_silu2->shape[2], W = net->act_silu2->shape[3]; + Tensor *grad_silu2 = global_avg_pool_backward(grad_pool, H, W); + tensor_free(grad_pool); - /* LeakyReLU backward (conv2 output) */ - Tensor *grad_conv2_out = apply_leaky_relu_backward(grad_relu2, net->act_conv2); - tensor_free(grad_relu2); + /* SiLU backward (conv2) */ + Tensor *grad_conv2_out = apply_silu_backward(grad_silu2, net->act_conv2); + tensor_free(grad_silu2); /* Conv2 backward */ - Tensor *grad_relu1 = conv2d_backward(&net->conv2, grad_conv2_out); + Tensor *grad_silu1 = conv2d_backward(&net->conv2, grad_conv2_out); tensor_free(grad_conv2_out); - /* LeakyReLU backward (conv1 output) */ - Tensor *grad_conv1_out = apply_leaky_relu_backward(grad_relu1, net->act_conv1); - tensor_free(grad_relu1); + /* SiLU backward (conv1) */ + Tensor *grad_conv1_out = apply_silu_backward(grad_silu1, net->act_conv1); + tensor_free(grad_silu1); /* Conv1 backward */ Tensor *grad_input = conv2d_backward(&net->conv1, grad_conv1_out); @@ -490,12 +499,8 @@ void network_adam_step(Network *net, float lr, float beta1, float beta2, float e adam_update(net->conv2.bias, 
net->conv2.grad_bias, net->conv2.m_bias, net->conv2.v_bias, lr, beta1, beta2, eps, t); adam_update(net->fc1.weight, net->fc1.grad_weight, net->fc1.m_weight, net->fc1.v_weight, lr, beta1, beta2, eps, t); adam_update(net->fc1.bias, net->fc1.grad_bias, net->fc1.m_bias, net->fc1.v_bias, lr, beta1, beta2, eps, t); - adam_update(net->head_shape.weight, net->head_shape.grad_weight, net->head_shape.m_weight, net->head_shape.v_weight, lr, beta1, beta2, eps, t); - adam_update(net->head_shape.bias, net->head_shape.grad_bias, net->head_shape.m_bias, net->head_shape.v_bias, lr, beta1, beta2, eps, t); - adam_update(net->head_ytype.weight, net->head_ytype.grad_weight, net->head_ytype.m_weight, net->head_ytype.v_weight, lr, beta1, beta2, eps, t); - adam_update(net->head_ytype.bias, net->head_ytype.grad_bias, net->head_ytype.m_bias, net->head_ytype.v_bias, lr, beta1, beta2, eps, t); - adam_update(net->head_lowheight.weight, net->head_lowheight.grad_weight, net->head_lowheight.m_weight, net->head_lowheight.v_weight, lr, beta1, beta2, eps, t); - adam_update(net->head_lowheight.bias, net->head_lowheight.grad_bias, net->head_lowheight.m_bias, net->head_lowheight.v_bias, lr, beta1, beta2, eps, t); + adam_update(net->output.weight, net->output.grad_weight, net->output.m_weight, net->output.v_weight, lr, beta1, beta2, eps, t); + adam_update(net->output.bias, net->output.grad_bias, net->output.m_bias, net->output.v_bias, lr, beta1, beta2, eps, t); } void network_zero_grad(Network *net) { @@ -505,34 +510,18 @@ void network_zero_grad(Network *net) { memset(net->conv2.grad_bias->data, 0, (size_t)net->conv2.grad_bias->size * sizeof(float)); memset(net->fc1.grad_weight->data, 0, (size_t)net->fc1.grad_weight->size * sizeof(float)); memset(net->fc1.grad_bias->data, 0, (size_t)net->fc1.grad_bias->size * sizeof(float)); - memset(net->head_shape.grad_weight->data, 0, (size_t)net->head_shape.grad_weight->size * sizeof(float)); - memset(net->head_shape.grad_bias->data, 0, 
(size_t)net->head_shape.grad_bias->size * sizeof(float)); - memset(net->head_ytype.grad_weight->data, 0, (size_t)net->head_ytype.grad_weight->size * sizeof(float)); - memset(net->head_ytype.grad_bias->data, 0, (size_t)net->head_ytype.grad_bias->size * sizeof(float)); - memset(net->head_lowheight.grad_weight->data, 0, (size_t)net->head_lowheight.grad_weight->size * sizeof(float)); - memset(net->head_lowheight.grad_bias->data, 0, (size_t)net->head_lowheight.grad_bias->size * sizeof(float)); + memset(net->output.grad_weight->data, 0, (size_t)net->output.grad_weight->size * sizeof(float)); + memset(net->output.grad_bias->data, 0, (size_t)net->output.grad_bias->size * sizeof(float)); } -float network_bce_loss(Network *net, Tensor *target_shape, Tensor *target_ytype, Tensor *target_lowheight) { +float network_bce_loss(Network *net, Tensor *target) { float loss = 0.0f; - int batch = net->out_shape->shape[0]; + int batch = net->out_all->shape[0]; + int n = batch * 12; - for (int i = 0; i < batch * 10; i++) { - float p = net->out_shape->data[i]; - float t = target_shape->data[i]; - p = fmaxf(1e-7f, fminf(1.0f - 1e-7f, p)); - loss -= t * logf(p) + (1.0f - t) * logf(1.0f - p); - } - for (int i = 0; i < batch; i++) { - float p = net->out_ytype->data[i]; - float t = target_ytype->data[i]; - p = fmaxf(1e-7f, fminf(1.0f - 1e-7f, p)); - loss -= t * logf(p) + (1.0f - t) * logf(1.0f - p); - } - for (int i = 0; i < batch; i++) { - float p = net->out_lowheight->data[i]; - float t = target_lowheight->data[i]; - p = fmaxf(1e-7f, fminf(1.0f - 1e-7f, p)); + for (int i = 0; i < n; i++) { + float p = fmaxf(1e-7f, fminf(1.0f - 1e-7f, net->out_all->data[i])); + float t = target->data[i]; loss -= t * logf(p) + (1.0f - t) * logf(1.0f - p); } @@ -546,11 +535,8 @@ void network_infer(Network *net, const float *input300, float *output12) { network_forward(net, input, 0); - /* output order: A,B,C,D,E,F,G,H,J,K, ytype, lowheight */ - for (int i = 0; i < 10; i++) - output12[i] = 
net->out_shape->data[i]; - output12[10] = net->out_ytype->data[0]; - output12[11] = net->out_lowheight->data[0]; + for (int i = 0; i < 12; i++) + output12[i] = net->out_all->data[i]; tensor_free(input); } diff --git a/Autokem/nn.h b/Autokem/nn.h index 9769be5..5bebf96 100644 --- a/Autokem/nn.h +++ b/Autokem/nn.h @@ -20,6 +20,7 @@ void tensor_free(Tensor *t); typedef struct { int in_ch, out_ch, kh, kw; + int pad_h, pad_w; Tensor *weight; /* [out_ch, in_ch, kh, kw] */ Tensor *bias; /* [out_ch] */ Tensor *grad_weight; @@ -45,35 +46,32 @@ typedef struct { /* ---- Network ---- */ typedef struct { - Conv2D conv1; /* 1->12, 3x3 */ - Conv2D conv2; /* 12->16, 3x3 */ - Dense fc1; /* 4800->24 */ - Dense head_shape; /* 24->10 (bits A-H, J, K) */ - Dense head_ytype; /* 24->1 */ - Dense head_lowheight;/* 24->1 */ + Conv2D conv1; /* 1->32, 7x7, pad=1 */ + Conv2D conv2; /* 32->64, 7x7, pad=1 */ + Dense fc1; /* 64->256 */ + Dense output; /* 256->12 (10 shape + 1 ytype + 1 lowheight) */ /* activation caches (allocated per forward) */ Tensor *act_conv1; - Tensor *act_relu1; + Tensor *act_silu1; Tensor *act_conv2; - Tensor *act_relu2; - Tensor *act_flat; + Tensor *act_silu2; + Tensor *act_pool; /* global average pool output */ Tensor *act_fc1; - Tensor *act_relu3; - Tensor *out_shape; - Tensor *out_ytype; - Tensor *out_lowheight; + Tensor *act_silu3; + Tensor *act_logits; /* pre-sigmoid */ + Tensor *out_all; /* sigmoid output [batch, 12] */ } Network; /* Init / free */ Network *network_create(void); void network_free(Network *net); -/* Forward pass. input: [batch, 1, 20, 15]. Outputs stored in net->out_* */ +/* Forward pass. input: [batch, 1, 20, 15]. Output stored in net->out_all */ void network_forward(Network *net, Tensor *input, int training); -/* Backward pass. targets: shape[batch,10], ytype[batch,1], lowheight[batch,1] */ -void network_backward(Network *net, Tensor *target_shape, Tensor *target_ytype, Tensor *target_lowheight); +/* Backward pass. 
target: [batch, 12] */ +void network_backward(Network *net, Tensor *target); /* Adam update step */ void network_adam_step(Network *net, float lr, float beta1, float beta2, float eps, int t); @@ -81,8 +79,8 @@ void network_adam_step(Network *net, float lr, float beta1, float beta2, float e /* Zero all gradients */ void network_zero_grad(Network *net); -/* Compute BCE loss (sum of all heads) */ -float network_bce_loss(Network *net, Tensor *target_shape, Tensor *target_ytype, Tensor *target_lowheight); +/* Compute BCE loss */ +float network_bce_loss(Network *net, Tensor *target); /* Single-sample inference: input float[300], output float[12] (A-H,J,K,ytype,lowheight) */ void network_infer(Network *net, const float *input300, float *output12); diff --git a/Autokem/safetensor.c b/Autokem/safetensor.c index 6873901..b44e929 100644 --- a/Autokem/safetensor.c +++ b/Autokem/safetensor.c @@ -31,18 +31,14 @@ static void collect_tensors(Network *net, TensorEntry *entries, int *count) { ADD("conv2.bias", conv2, bias); ADD("fc1.weight", fc1, weight); ADD("fc1.bias", fc1, bias); - ADD("head_shape.weight", head_shape, weight); - ADD("head_shape.bias", head_shape, bias); - ADD("head_ytype.weight", head_ytype, weight); - ADD("head_ytype.bias", head_ytype, bias); - ADD("head_lowheight.weight", head_lowheight, weight); - ADD("head_lowheight.bias", head_lowheight, bias); + ADD("output.weight", output, weight); + ADD("output.bias", output, bias); #undef ADD *count = n; } int safetensor_save(const char *path, Network *net, int total_samples, int epochs, float val_loss) { - TensorEntry entries[12]; + TensorEntry entries[8]; int count; collect_tensors(net, entries, &count); @@ -149,7 +145,7 @@ int safetensor_load(const char *path, Network *net) { long data_start = 8 + (long)header_len; - TensorEntry entries[12]; + TensorEntry entries[8]; int count; collect_tensors(net, entries, &count); @@ -234,14 +230,12 @@ int safetensor_stats(const char *path) { const char *tensor_names[] = { 
"conv1.weight", "conv1.bias", "conv2.weight", "conv2.bias", "fc1.weight", "fc1.bias", - "head_shape.weight", "head_shape.bias", - "head_ytype.weight", "head_ytype.bias", - "head_lowheight.weight", "head_lowheight.bias" + "output.weight", "output.bias" }; int total_params = 0; printf("\nTensors:\n"); - for (int i = 0; i < 12; i++) { + for (int i = 0; i < 8; i++) { size_t off_start, off_end; if (find_tensor_offsets(json, (size_t)header_len, tensor_names[i], &off_start, &off_end) == 0) { int params = (int)(off_end - off_start) / 4; diff --git a/Autokem/train.c b/Autokem/train.c index f22b981..c647296 100644 --- a/Autokem/train.c +++ b/Autokem/train.c @@ -128,12 +128,8 @@ static void save_weights(Network *net, Network *best) { copy_tensor_data(best->conv2.bias, net->conv2.bias); copy_tensor_data(best->fc1.weight, net->fc1.weight); copy_tensor_data(best->fc1.bias, net->fc1.bias); - copy_tensor_data(best->head_shape.weight, net->head_shape.weight); - copy_tensor_data(best->head_shape.bias, net->head_shape.bias); - copy_tensor_data(best->head_ytype.weight, net->head_ytype.weight); - copy_tensor_data(best->head_ytype.bias, net->head_ytype.bias); - copy_tensor_data(best->head_lowheight.weight, net->head_lowheight.weight); - copy_tensor_data(best->head_lowheight.bias, net->head_lowheight.bias); + copy_tensor_data(best->output.weight, net->output.weight); + copy_tensor_data(best->output.bias, net->output.bias); } /* ---- Training ---- */ @@ -247,19 +243,15 @@ int train_model(void) { int ishape[] = {bs, 1, 20, 15}; Tensor *input = tensor_alloc(4, ishape); - int sshape[] = {bs, 10}; - Tensor *tgt_shape = tensor_alloc(2, sshape); - - int yshape[] = {bs, 1}; - Tensor *tgt_ytype = tensor_alloc(2, yshape); - Tensor *tgt_lh = tensor_alloc(2, yshape); + int tshape[] = {bs, 12}; + Tensor *target = tensor_alloc(2, tshape); for (int i = 0; i < bs; i++) { Sample *s = &all_samples[indices[start + i]]; memcpy(input->data + i * 300, s->input, 300 * sizeof(float)); - memcpy(tgt_shape->data + 
i * 10, s->shape, 10 * sizeof(float)); - tgt_ytype->data[i] = s->ytype; - tgt_lh->data[i] = s->lowheight; + memcpy(target->data + i * 12, s->shape, 10 * sizeof(float)); + target->data[i * 12 + 10] = s->ytype; + target->data[i * 12 + 11] = s->lowheight; } /* Forward */ @@ -267,21 +259,19 @@ int train_model(void) { network_forward(net, input, 1); /* Loss */ - float loss = network_bce_loss(net, tgt_shape, tgt_ytype, tgt_lh); + float loss = network_bce_loss(net, target); train_loss += loss; n_batches++; /* Backward */ - network_backward(net, tgt_shape, tgt_ytype, tgt_lh); + network_backward(net, target); /* Adam step */ adam_t++; network_adam_step(net, lr, beta1, beta2, eps, adam_t); tensor_free(input); - tensor_free(tgt_shape); - tensor_free(tgt_ytype); - tensor_free(tgt_lh); + tensor_free(target); } train_loss /= (float)n_batches; @@ -295,29 +285,23 @@ int train_model(void) { int ishape[] = {bs, 1, 20, 15}; Tensor *input = tensor_alloc(4, ishape); - int sshape[] = {bs, 10}; - Tensor *tgt_shape = tensor_alloc(2, sshape); - - int yshape[] = {bs, 1}; - Tensor *tgt_ytype = tensor_alloc(2, yshape); - Tensor *tgt_lh = tensor_alloc(2, yshape); + int tshape[] = {bs, 12}; + Tensor *target = tensor_alloc(2, tshape); for (int i = 0; i < bs; i++) { Sample *s = &all_samples[indices[n_train + start + i]]; memcpy(input->data + i * 300, s->input, 300 * sizeof(float)); - memcpy(tgt_shape->data + i * 10, s->shape, 10 * sizeof(float)); - tgt_ytype->data[i] = s->ytype; - tgt_lh->data[i] = s->lowheight; + memcpy(target->data + i * 12, s->shape, 10 * sizeof(float)); + target->data[i * 12 + 10] = s->ytype; + target->data[i * 12 + 11] = s->lowheight; } network_forward(net, input, 0); - val_loss += network_bce_loss(net, tgt_shape, tgt_ytype, tgt_lh); + val_loss += network_bce_loss(net, target); val_batches++; tensor_free(input); - tensor_free(tgt_shape); - tensor_free(tgt_ytype); - tensor_free(tgt_lh); + tensor_free(target); } val_loss /= (float)val_batches; diff --git 
a/Autokem/train_torch.py b/Autokem/train_torch.py new file mode 100644 index 0000000..6b1f9c0 --- /dev/null +++ b/Autokem/train_torch.py @@ -0,0 +1,496 @@ +#!/usr/bin/env python3 +""" +PyTorch training script for Autokem — drop-in replacement for `autokem train`. + +Reads the same *_variable.tga sprite sheets, trains the same architecture, +and saves weights in safetensors format loadable by the C inference code. + +Usage: + python train_keras.py # train with defaults + python train_keras.py --epochs 300 # override max epochs + python train_keras.py --lr 0.0005 # override learning rate + python train_keras.py --save model.safetensors + +Requirements: + pip install torch numpy +""" + +import argparse +import json +import os +import struct +import sys +from pathlib import Path + +import numpy as np + +# ---- TGA reader (matches OTFbuild/tga_reader.py and Autokem/tga.c) ---- + +class TgaImage: + __slots__ = ('width', 'height', 'pixels') + + def __init__(self, width, height, pixels): + self.width = width + self.height = height + self.pixels = pixels # flat list of RGBA8888 ints + + def get_pixel(self, x, y): + if x < 0 or x >= self.width or y < 0 or y >= self.height: + return 0 + return self.pixels[y * self.width + x] + + +def read_tga(path): + with open(path, 'rb') as f: + data = f.read() + + pos = 0 + id_length = data[pos]; pos += 1 + _colour_map_type = data[pos]; pos += 1 + image_type = data[pos]; pos += 1 + pos += 5 # colour map spec + pos += 2 # x_origin + pos += 2 # y_origin + width = struct.unpack_from('> 8) & 0xFFFFFF + is_low_height = 1.0 if (img.get_pixel(tag_x, tag_y + 5) & 0xFF) != 0 else 0.0 + + # Shape bits: A(7) B(6) C(5) D(4) E(3) F(2) G(1) H(0) J(15) K(14) + shape = [ + float((kerning_mask >> 7) & 1), # A + float((kerning_mask >> 6) & 1), # B + float((kerning_mask >> 5) & 1), # C + float((kerning_mask >> 4) & 1), # D + float((kerning_mask >> 3) & 1), # E + float((kerning_mask >> 2) & 1), # F + float((kerning_mask >> 1) & 1), # G + float((kerning_mask 
>> 0) & 1), # H + float((kerning_mask >> 15) & 1), # J + float((kerning_mask >> 14) & 1), # K + ] + + # 15x20 binary input + inp = np.zeros((20, 15), dtype=np.float32) + for gy in range(20): + for gx in range(15): + p = img.get_pixel(cell_x + gx, cell_y + gy) + if (p & 0x80) != 0: + inp[gy, gx] = 1.0 + + inputs.append(inp) + labels.append(shape + [is_kern_ytype, is_low_height]) + + return inputs, labels + + +def collect_all_samples(assets_dir): + """Scan assets_dir for *_variable.tga, collect all labelled samples.""" + all_inputs = [] + all_labels = [] + file_count = 0 + + for name in sorted(os.listdir(assets_dir)): + if not name.endswith('_variable.tga'): + continue + if 'extrawide' in name: + continue + + is_xyswap = 'xyswap' in name + path = os.path.join(assets_dir, name) + inputs, labels = collect_from_sheet(path, is_xyswap) + if inputs: + print(f" {name}: {len(inputs)} samples") + all_inputs.extend(inputs) + all_labels.extend(labels) + file_count += 1 + + return np.array(all_inputs), np.array(all_labels, dtype=np.float32), file_count + + +# ---- Model (matches Autokem/nn.c architecture) ---- + +def build_model(): + """ + Conv2D(1->32, 7x7, padding=1) -> SiLU + Conv2D(32->64, 7x7, padding=1) -> SiLU + GlobalAveragePooling2D -> [64] + Dense(256) -> SiLU + Dense(12) -> sigmoid + """ + import torch + import torch.nn as nn + + class Keminet(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 32, 7, padding=1) + self.conv2 = nn.Conv2d(32, 64, 7, padding=1) + self.fc1 = nn.Linear(64, 256) + # self.fc2 = nn.Linear(256, 48) + self.output = nn.Linear(256, 12) + self.tf = nn.SiLU() + + # He init + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.Linear)): + nn.init.kaiming_normal_(m.weight, a=0.01, nonlinearity='leaky_relu') + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, x): + x = self.tf(self.conv1(x)) + x = self.tf(self.conv2(x)) + x = x.mean(dim=(2, 3)) # global average pool + x = self.tf(self.fc1(x)) + 
# x = self.tf(self.fc2(x)) + x = torch.sigmoid(self.output(x)) + return x + + return Keminet() + + +# ---- Safetensors export (matches Autokem/safetensor.c layout) ---- + +def export_safetensors(model, path, total_samples, epochs, val_loss): + """ + Save model weights in safetensors format compatible with the C code. + + C code expects these tensor names with these shapes: + conv1.weight [out_ch, in_ch, kh, kw] — PyTorch matches this layout + conv1.bias [out_ch] + conv2.weight [out_ch, in_ch, kh, kw] + conv2.bias [out_ch] + fc1.weight [out_features, in_features] — PyTorch matches this layout + fc1.bias [out_features] + fc2.weight [out_features, in_features] (not exported — fc2 layer currently disabled) + fc2.bias [out_features] (not exported — fc2 layer currently disabled) + output.weight [out_features, in_features] + output.bias [out_features] + """ + tensor_names = [ + 'conv1.weight', 'conv1.bias', + 'conv2.weight', 'conv2.bias', + 'fc1.weight', 'fc1.bias', + # 'fc2.weight', 'fc2.bias', + 'output.weight', 'output.bias', + ] + + state = model.state_dict() + + header = {} + header['__metadata__'] = { + 'samples': str(total_samples), + 'epochs': str(epochs), + 'val_loss': f'{val_loss:.6f}', + } + + data_parts = [] + offset = 0 + for name in tensor_names: + arr = state[name].detach().cpu().numpy().astype(np.float32) + raw = arr.tobytes() + header[name] = { + 'dtype': 'F32', + 'shape': list(arr.shape), + 'data_offsets': [offset, offset + len(raw)], + } + data_parts.append(raw) + offset += len(raw) + + header_json = json.dumps(header, separators=(',', ':')).encode('utf-8') + padded_len = (len(header_json) + 7) & ~7 + header_json = header_json + b' ' * (padded_len - len(header_json)) + + with open(path, 'wb') as f: + f.write(struct.pack('= 0.5).astype(int) + tgt_bits = y_np.astype(int) + + n_val = len(y_np) + n_examples = 0 + + print("\nGlyph Tags — validation predictions:") + for i in range(n_val): + mismatch = not np.array_equal(pred_bits[i], tgt_bits[i]) + if n_examples < max_examples and (mismatch or i < 4): + actual_tag = format_tag(tgt_bits[i]) +
pred_tag = format_tag(pred_bits[i]) + status = 'MISMATCH' if mismatch else 'ok' + print(f" actual={actual_tag:<20s} pred={pred_tag:<20s} {status}") + n_examples += 1 + + correct = (pred_bits == tgt_bits) + per_bit = correct.sum(axis=0) + total_correct = correct.sum() + + print(f"\nPer-bit accuracy ({n_val} val samples):") + parts = [f'{BIT_NAMES[b]}:{100*per_bit[b]/n_val:.1f}%' for b in range(12)] + print(f" {' '.join(parts)}") + print(f" Overall: {total_correct}/{n_val*12} ({100*total_correct/(n_val*12):.2f}%)") + + +# ---- Main ---- + +def main(): + parser = argparse.ArgumentParser(description='Train Autokem model (PyTorch)') + parser.add_argument('--assets', default='../src/assets', + help='Path to assets directory (default: ../src/assets)') + parser.add_argument('--save', default='autokem.safetensors', + help='Output safetensors path (default: autokem.safetensors)') + parser.add_argument('--load', default=None, + help='Load weights from safetensors before training') + parser.add_argument('--epochs', type=int, default=200, help='Max epochs (default: 200)') + parser.add_argument('--batch-size', type=int, default=32, help='Batch size (default: 32)') + parser.add_argument('--lr', type=float, default=0.001, help='Learning rate (default: 0.001)') + parser.add_argument('--patience', type=int, default=10, + help='Early stopping patience (default: 10)') + parser.add_argument('--val-split', type=float, default=0.2, + help='Validation split (default: 0.2)') + args = parser.parse_args() + + import torch + import torch.nn as nn + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"Device: {device}") + + # Collect data + print("Collecting samples...") + X, y, file_count = collect_all_samples(args.assets) + + if len(X) < 10: + print(f"Error: too few samples ({len(X)})", file=sys.stderr) + return 1 + + total = len(X) + print(f"Collected {total} samples from {file_count} sheets") + print_label_distribution(y, total) + + nonzero = 
np.any(X.reshape(total, -1) > 0.5, axis=1).sum() + print(f" Non-empty inputs: {nonzero}/{total}\n") + + # Shuffle and split + rng = np.random.default_rng(42) + perm = rng.permutation(total) + X, y = X[perm], y[perm] + + n_train = int(total * (1 - args.val_split)) + X_train, X_val = X[:n_train], X[n_train:] + y_train, y_val = y[:n_train], y[n_train:] + print(f"Train: {n_train}, Validation: {total - n_train}\n") + + # Convert to tensors — PyTorch conv expects [N, C, H, W] + X_train_t = torch.from_numpy(X_train[:, np.newaxis, :, :]).to(device) # [N,1,20,15] + y_train_t = torch.from_numpy(y_train).to(device) + X_val_t = torch.from_numpy(X_val[:, np.newaxis, :, :]).to(device) + y_val_t = torch.from_numpy(y_val).to(device) + + # Build model + model = build_model().to(device) + + if args.load: + load_safetensors(model, args.load) + + total_params = sum(p.numel() for p in model.parameters()) + print(f"Model parameters: {total_params} ({total_params * 4 / 1024:.1f} KB)\n") + + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + loss_fn = nn.BCELoss() + + best_val_loss = float('inf') + best_epoch = 0 + patience_counter = 0 + best_state = None + + for epoch in range(1, args.epochs + 1): + # Training + model.train() + perm_train = torch.randperm(n_train, device=device) + train_loss = 0.0 + n_batches = 0 + + for start in range(0, n_train, args.batch_size): + end = min(start + args.batch_size, n_train) + idx = perm_train[start:end] + + optimizer.zero_grad() + pred = model(X_train_t[idx]) + loss = loss_fn(pred, y_train_t[idx]) + loss.backward() + optimizer.step() + + train_loss += loss.item() + n_batches += 1 + + train_loss /= n_batches + + # Validation + model.eval() + with torch.no_grad(): + val_pred = model(X_val_t) + val_loss = loss_fn(val_pred, y_val_t).item() + + marker = '' + if val_loss < best_val_loss: + best_val_loss = val_loss + best_epoch = epoch + patience_counter = 0 + best_state = {k: v.clone() for k, v in model.state_dict().items()} + marker = ' *best*' 
+ else: + patience_counter += 1 + + print(f"Epoch {epoch:3d}: train_loss={train_loss:.4f} val_loss={val_loss:.4f}{marker}") + + if patience_counter >= args.patience: + print(f"\nEarly stopping at epoch {epoch} (best epoch: {best_epoch})") + break + + # Restore best weights + if best_state is not None: + model.load_state_dict(best_state) + + print(f"\nBest epoch: {best_epoch}, val_loss: {best_val_loss:.6f}") + + # Print accuracy + model.eval() + print_examples_and_accuracy(model, X_val_t, y_val, max_examples=8) + + # Save + export_safetensors(model, args.save, total, best_epoch, best_val_loss) + + return 0 + + +if __name__ == '__main__': + sys.exit(main())