mirror of
https://github.com/curioustorvald/Terrarum-sans-bitmap.git
synced 2026-03-11 05:41:51 +09:00
497 lines
16 KiB
Python
497 lines
16 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
PyTorch training script for Autokem — drop-in replacement for `autokem train`.

Reads the same *_variable.tga sprite sheets, trains the same architecture,
and saves weights in safetensors format loadable by the C inference code.

Usage:
    python train_keras.py                          # train with defaults
    python train_keras.py --epochs 300             # override max epochs
    python train_keras.py --lr 0.0005              # override learning rate
    python train_keras.py --save model.safetensors

Requirements:
    pip install torch numpy
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import struct
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
|
|
# ---- TGA reader (matches OTFbuild/tga_reader.py and Autokem/tga.c) ----
|
|
|
|
class TgaImage:
    """Decoded bitmap held as a flat, row-major list of packed pixel ints."""

    __slots__ = ('width', 'height', 'pixels')

    def __init__(self, width, height, pixels):
        self.width, self.height = width, height
        # Flat list of RGBA8888 ints, addressed as y * width + x.
        self.pixels = pixels

    def get_pixel(self, x, y):
        """Return the packed pixel at (x, y), or 0 for out-of-range coordinates."""
        inside = 0 <= x < self.width and 0 <= y < self.height
        return self.pixels[y * self.width + x] if inside else 0
|
|
|
|
|
|
def read_tga(path):
    """Load an uncompressed true-colour TGA file into a TgaImage.

    Only image type 2 with 3 or 4 bytes per pixel is accepted. Pixels are
    stored BGR(A) in the file and are repacked as 0xRRGGBBAA ints; images
    without an alpha channel get an alpha of 0xFF. Rows are reordered so that
    index 0 of the result is always the top row.

    Args:
        path: Path of the .tga file to read.

    Returns:
        A TgaImage with `width * height` packed pixels.

    Raises:
        ValueError: If the image type or depth is unsupported, or the file is
            too short to hold the header or the declared pixel data.
    """
    with open(path, 'rb') as f:
        data = f.read()

    header_size = 18
    if len(data) < header_size:
        raise ValueError(
            f"Truncated TGA: {len(data)} bytes, need at least {header_size} for the header")

    # Header layout: id length (0), colour map type (1), image type (2),
    # colour map spec (3-7), x/y origin (8-11), width (12), height (14),
    # bits per pixel (16), descriptor (17). Unused fields are skipped.
    id_length = data[0]
    image_type = data[2]
    width, height = struct.unpack_from('<HH', data, 12)
    bits_per_pixel = data[16]
    descriptor = data[17]
    top_to_bottom = (descriptor & 0x20) != 0
    bpp = bits_per_pixel // 8

    if image_type != 2 or bpp not in (3, 4):
        raise ValueError(f"Unsupported TGA: type={image_type}, bpp={bits_per_pixel}")

    # Pixel data starts right after the optional image-ID field.
    pos = header_size + id_length
    needed = pos + width * height * bpp
    if len(data) < needed:
        # Fail up front instead of with an IndexError half-way through decoding.
        raise ValueError(
            f"Truncated TGA: {len(data)} bytes, need {needed} for {width}x{height} pixels")

    pixels = [0] * (width * height)
    for row in range(height):
        # Files without the top-to-bottom flag store the bottom row first.
        y = row if top_to_bottom else (height - 1 - row)
        row_base = y * width
        for x in range(width):
            b, g, r = data[pos], data[pos + 1], data[pos + 2]
            a = data[pos + 3] if bpp == 4 else 0xFF
            pos += bpp
            pixels[row_base + x] = (r << 24) | (g << 16) | (b << 8) | a

    return TgaImage(width, height, pixels)
|
|
|
|
|
|
def tagify(pixel):
    """Return *pixel* unchanged, or 0 when its lowest byte is zero."""
    if pixel & 0xFF:
        return pixel
    return 0
|
|
|
|
|
|
# ---- Data collection (matches Autokem/train.c) ----
|
|
|
|
def collect_from_sheet(path, is_xyswap):
    """Extract labelled samples from a single TGA sheet.

    The sheet is cut into 16x20 cells. The right-most pixel column of each
    cell carries the tags: rows 0-4 hold the 5-bit width, row 5 the
    low-height flag and row 6 the kerning data. Cells with a zero width or
    without kerning data are skipped.

    Args:
        path: Path of the sprite sheet.
        is_xyswap: When true, cells are enumerated column-first instead of
            row-first.

    Returns:
        (inputs, labels): `inputs` is a list of (20, 15) float32 arrays with
        1.0 wherever bit 7 of the pixel's low byte is set; `labels` is a list
        of 12 floats per sample (shape bits A-H, J, K, then Y-type and
        low-height).
    """
    img = read_tga(path)
    cell_w, cell_h = 16, 20
    cols = img.width // cell_w
    rows = img.height // cell_h

    # Bit positions inside the kerning mask, in label order A..H, J, K.
    shape_bits = (7, 6, 5, 4, 3, 2, 1, 0, 15, 14)

    inputs = []
    labels = []

    for index in range(cols * rows):
        major, minor = divmod(index, cols)
        if is_xyswap:
            cell_x, cell_y = major * cell_w, minor * cell_h
        else:
            cell_x, cell_y = minor * cell_w, major * cell_h

        # Tag column is the last pixel column of the cell.
        tag_x = cell_x + (cell_w - 1)
        tag_y = cell_y

        # Width (5-bit): a cell whose five width pixels are all clear is unused.
        if not any(img.get_pixel(tag_x, tag_y + bit) & 0xFF for bit in range(5)):
            continue

        # Kern data pixel at Y+6
        kern_pixel = tagify(img.get_pixel(tag_x, tag_y + 6))
        if not (kern_pixel & 0xFF):
            continue  # no kern data

        # Labels: top bit of the pixel is the Y-type flag, the upper three
        # bytes form the kerning mask, and the Y+5 tag pixel flags low height.
        kerning_mask = (kern_pixel >> 8) & 0xFFFFFF
        shape = [float((kerning_mask >> bit) & 1) for bit in shape_bits]
        is_kern_ytype = float((kern_pixel & 0x80000000) != 0)
        is_low_height = float((img.get_pixel(tag_x, tag_y + 5) & 0xFF) != 0)

        # 15x20 binary input (the tag column itself is excluded).
        inp = np.array(
            [
                [1.0 if img.get_pixel(cell_x + gx, cell_y + gy) & 0x80 else 0.0
                 for gx in range(15)]
                for gy in range(20)
            ],
            dtype=np.float32,
        )

        inputs.append(inp)
        labels.append(shape + [is_kern_ytype, is_low_height])

    return inputs, labels
|
|
|
|
|
|
def collect_all_samples(assets_dir):
    """Scan assets_dir for *_variable.tga, collect all labelled samples.

    Files whose name contains 'extrawide' are ignored; files whose name
    contains 'xyswap' are read with swapped cell ordering. Sheets are visited
    in sorted name order.

    Returns:
        (inputs, labels, file_count): stacked input array, float32 label
        array, and the number of sheets that contributed at least one sample.
    """
    sample_grids = []
    sample_labels = []
    sheets_used = 0

    for name in sorted(os.listdir(assets_dir)):
        wanted = name.endswith('_variable.tga') and 'extrawide' not in name
        if not wanted:
            continue

        sheet_inputs, sheet_labels = collect_from_sheet(
            os.path.join(assets_dir, name), 'xyswap' in name)
        if not sheet_inputs:
            continue

        print(f" {name}: {len(sheet_inputs)} samples")
        sample_grids += sheet_inputs
        sample_labels += sheet_labels
        sheets_used += 1

    return (
        np.array(sample_grids),
        np.array(sample_labels, dtype=np.float32),
        sheets_used,
    )
|
|
|
|
|
|
# ---- Model (matches Autokem/nn.c architecture) ----
|
|
|
|
def build_model():
    """Construct a freshly initialised Autokem network.

    Layer stack:
        Conv2D(1->32, 7x7, padding=1) -> SiLU
        Conv2D(32->64, 7x7, padding=1) -> SiLU
        GlobalAveragePooling2D -> [64]
        Dense(256) -> SiLU
        Dense(12) -> sigmoid

    Weights use Kaiming-normal initialisation and biases start at zero.
    """
    import torch
    import torch.nn as nn

    class Keminet(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 32, kernel_size=7, padding=1)
            self.conv2 = nn.Conv2d(32, 64, kernel_size=7, padding=1)
            self.fc1 = nn.Linear(64, 256)
            self.output = nn.Linear(256, 12)
            self.tf = nn.SiLU()

            # He init, applied to the trainable layers in creation order.
            for layer in (self.conv1, self.conv2, self.fc1, self.output):
                nn.init.kaiming_normal_(layer.weight, a=0.01, nonlinearity='leaky_relu')
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

        def forward(self, x):
            activate = self.tf
            features = activate(self.conv1(x))
            features = activate(self.conv2(features))
            pooled = features.mean(dim=(2, 3))  # global average pool
            hidden = activate(self.fc1(pooled))
            return torch.sigmoid(self.output(hidden))

    return Keminet()
|
|
|
|
|
|
# ---- Safetensors export (matches Autokem/safetensor.c layout) ----
|
|
|
|
def export_safetensors(model, path, total_samples, epochs, val_loss):
    """Save model weights in safetensors format compatible with the C code.

    File layout: an 8-byte little-endian header length, the JSON header
    (space-padded to a multiple of 8 bytes), then the raw tensor data in the
    order listed below. All tensors are written as little-endian float32.

    Tensors written, with the shapes PyTorch already uses:
        conv1.weight   [out_ch, in_ch, kh, kw]
        conv1.bias     [out_ch]
        conv2.weight   [out_ch, in_ch, kh, kw]
        conv2.bias     [out_ch]
        fc1.weight     [out_features, in_features]
        fc1.bias       [out_features]
        output.weight  [out_features, in_features]
        output.bias    [out_features]

    Args:
        model: Object whose `state_dict()` contains the tensors above.
        path: Destination file path.
        total_samples: Sample count, recorded in the header metadata.
        epochs: Epoch count, recorded in the header metadata.
        val_loss: Validation loss, recorded with six decimals.
    """
    tensor_names = (
        'conv1.weight', 'conv1.bias',
        'conv2.weight', 'conv2.bias',
        'fc1.weight', 'fc1.bias',
        'output.weight', 'output.bias',
    )

    state = model.state_dict()

    # Metadata values are stored as strings.
    header = {
        '__metadata__': {
            'samples': str(total_samples),
            'epochs': str(epochs),
            'val_loss': f'{val_loss:.6f}',
        },
    }

    data_parts = []
    offset = 0
    for name in tensor_names:
        # '<f4' pins the on-disk byte order to little-endian regardless of
        # the host, as the safetensors format requires.
        arr = state[name].detach().cpu().numpy().astype('<f4')
        raw = arr.tobytes()
        header[name] = {
            'dtype': 'F32',
            'shape': list(arr.shape),
            'data_offsets': [offset, offset + len(raw)],
        }
        data_parts.append(raw)
        offset += len(raw)

    header_json = json.dumps(header, separators=(',', ':')).encode('utf-8')
    # Pad the header with spaces so the tensor data starts 8-byte aligned.
    padded_len = (len(header_json) + 7) & ~7
    header_json = header_json.ljust(padded_len, b' ')

    with open(path, 'wb') as f:
        f.write(struct.pack('<Q', len(header_json)))
        f.write(header_json)
        f.writelines(data_parts)

    total_bytes = 8 + len(header_json) + offset
    print(f"Saved model to {path} ({total_bytes} bytes)")
|
|
|
|
|
|
def load_safetensors(model, path):
    """Load weights from safetensors file into the PyTorch model.

    Every tensor named in `model.state_dict()` is looked up in the file's
    JSON header. Tensors missing from the file produce a warning and keep
    their current values. Tensor data is decoded as little-endian float32.

    Args:
        model: Object providing `state_dict()` and `load_state_dict()`.
        path: Path of the safetensors file.

    Raises:
        ValueError: If a matching tensor is stored with a dtype other than
            F32, or its data does not fit the declared shape.
    """
    import torch

    with open(path, 'rb') as f:
        # 8-byte little-endian header length, then the JSON header itself.
        header_len = struct.unpack('<Q', f.read(8))[0]
        header = json.loads(f.read(header_len))
        data_start = 8 + header_len

        state = model.state_dict()
        for name in state:
            if name not in header:
                print(f" Warning: tensor '{name}' not in safetensors")
                continue

            entry = header[name]
            dtype = entry.get('dtype', 'F32')
            if dtype != 'F32':
                # Reinterpreting other dtypes as float32 would load garbage.
                raise ValueError(f"Tensor '{name}' has unsupported dtype {dtype}; expected F32")

            off_start, off_end = entry['data_offsets']
            # Offsets are relative to the end of the header.
            f.seek(data_start + off_start)
            raw = f.read(off_end - off_start)
            arr = np.frombuffer(raw, dtype='<f4').reshape(entry['shape'])
            # astype() yields a writable, native-endian copy for torch.
            state[name] = torch.from_numpy(arr.astype(np.float32))

    model.load_state_dict(state)
    print(f"Loaded weights from {path}")
|
|
|
|
|
|
# ---- Pretty-print helpers ----
|
|
|
|
# Letters of the ten shape bits, in label order.
SHAPE_CHARS = 'ABCDEFGHJK'

# Display names of the 12 label columns: the shape bits followed by the two flags.
BIT_NAMES = [*SHAPE_CHARS, 'Ytype', 'LowH']
|
|
|
|
|
|
def format_tag(bits_12):
    """Format 12 binary bits as keming_machine tag string, e.g. 'ABCDEFGH(B)'.

    The first ten bits select shape letters, bit 10 picks the '(Y)' or '(B)'
    mode suffix and bit 11 appends ' low'.
    """
    letters = [SHAPE_CHARS[pos] for pos in range(10) if bits_12[pos]]
    shape_part = ''.join(letters) or '(empty)'
    mode_part = '(Y)' if bits_12[10] else '(B)'
    low_part = ' low' if bits_12[11] else ''
    return shape_part + mode_part + low_part
|
|
|
|
|
|
def print_label_distribution(labels, total):
    """Print, for each of the 12 label bits, how many samples set it and the share of *total*."""
    counts = labels.sum(axis=0).astype(int)
    parts = []
    for bit in range(12):
        share = 100 * counts[bit] / total
        parts.append(f'{BIT_NAMES[bit]}:{counts[bit]}({share:.0f}%)')
    summary = ' '.join(parts)
    print(f"Label distribution:\n {summary}")
|
|
|
|
|
|
def print_examples_and_accuracy(model, X_val, y_val, max_examples=8):
    """Print example predictions and per-bit accuracy on validation set.

    Predictions are thresholded at 0.5. Up to *max_examples* rows are listed:
    any of the first four samples plus every later mismatch, in order.
    *y_val* may be a tensor or a NumPy array.
    """
    import torch

    model.eval()
    with torch.no_grad():
        preds = model(X_val).cpu().numpy()

    targets = y_val.cpu().numpy() if hasattr(y_val, 'cpu') else y_val
    pred_bits = (preds >= 0.5).astype(int)
    tgt_bits = targets.astype(int)
    n_val = len(targets)

    print("\nGlyph Tags — validation predictions:")
    shown = 0
    for i in range(n_val):
        if shown >= max_examples:
            break  # nothing further can be printed
        mismatch = not np.array_equal(pred_bits[i], tgt_bits[i])
        if not (mismatch or i < 4):
            continue
        actual_tag = format_tag(tgt_bits[i])
        pred_tag = format_tag(pred_bits[i])
        status = 'MISMATCH' if mismatch else 'ok'
        print(f" actual={actual_tag:<20s} pred={pred_tag:<20s} {status}")
        shown += 1

    correct = (pred_bits == tgt_bits)
    per_bit = correct.sum(axis=0)
    total_correct = correct.sum()
    n_bits = n_val * 12

    print(f"\nPer-bit accuracy ({n_val} val samples):")
    parts = []
    for bit in range(12):
        parts.append(f'{BIT_NAMES[bit]}:{100*per_bit[bit]/n_val:.1f}%')
    summary = ' '.join(parts)
    print(f" {summary}")
    print(f" Overall: {total_correct}/{n_bits} ({100*total_correct/n_bits:.2f}%)")
|
|
|
|
|
|
# ---- Main ----
|
|
|
|
def main():
    """Parse the command line, train the model and export the best weights.

    Collects samples from the assets directory, shuffles them with a fixed
    seed, splits off a validation set, trains with Adam and binary
    cross-entropy, stops early once the validation loss has not improved for
    `--patience` epochs, restores the best weights and writes them to
    `--save`.

    Returns:
        0 on success, 1 when there are too few samples or the requested
        validation split leaves the training or validation set empty.
    """
    parser = argparse.ArgumentParser(description='Train Autokem model (PyTorch)')
    parser.add_argument('--assets', default='../src/assets',
                        help='Path to assets directory (default: ../src/assets)')
    parser.add_argument('--save', default='autokem.safetensors',
                        help='Output safetensors path (default: autokem.safetensors)')
    parser.add_argument('--load', default=None,
                        help='Load weights from safetensors before training')
    parser.add_argument('--epochs', type=int, default=200, help='Max epochs (default: 200)')
    parser.add_argument('--batch-size', type=int, default=32, help='Batch size (default: 32)')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate (default: 0.001)')
    parser.add_argument('--patience', type=int, default=10,
                        help='Early stopping patience (default: 10)')
    parser.add_argument('--val-split', type=float, default=0.2,
                        help='Validation split (default: 0.2)')
    args = parser.parse_args()

    if args.batch_size < 1:
        # A zero or negative step would crash the batching loop below.
        parser.error('--batch-size must be at least 1')

    import torch
    import torch.nn as nn

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Collect data
    print("Collecting samples...")
    X, y, file_count = collect_all_samples(args.assets)

    if len(X) < 10:
        print(f"Error: too few samples ({len(X)})", file=sys.stderr)
        return 1

    total = len(X)
    print(f"Collected {total} samples from {file_count} sheets")
    print_label_distribution(y, total)

    nonzero = np.any(X.reshape(total, -1) > 0.5, axis=1).sum()
    print(f" Non-empty inputs: {nonzero}/{total}\n")

    # Shuffle and split (fixed seed keeps the split reproducible)
    rng = np.random.default_rng(42)
    perm = rng.permutation(total)
    X, y = X[perm], y[perm]

    n_train = int(total * (1 - args.val_split))
    if n_train < 1 or n_train >= total:
        # An empty training set divides by zero below; an empty validation
        # set yields a NaN loss, so no best state would ever be recorded.
        print(f"Error: --val-split {args.val_split} leaves an empty training "
              f"or validation set ({total} samples)", file=sys.stderr)
        return 1

    X_train, X_val = X[:n_train], X[n_train:]
    y_train, y_val = y[:n_train], y[n_train:]
    print(f"Train: {n_train}, Validation: {total - n_train}\n")

    # Convert to tensors — PyTorch conv expects [N, C, H, W]
    X_train_t = torch.from_numpy(X_train[:, np.newaxis, :, :]).to(device)  # [N,1,20,15]
    y_train_t = torch.from_numpy(y_train).to(device)
    X_val_t = torch.from_numpy(X_val[:, np.newaxis, :, :]).to(device)
    y_val_t = torch.from_numpy(y_val).to(device)

    # Build model
    model = build_model().to(device)

    if args.load:
        load_safetensors(model, args.load)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {total_params} ({total_params * 4 / 1024:.1f} KB)\n")

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    loss_fn = nn.BCELoss()

    best_val_loss = float('inf')
    best_epoch = 0
    patience_counter = 0
    best_state = None

    for epoch in range(1, args.epochs + 1):
        # Training: one pass over a fresh random ordering of the train set
        model.train()
        perm_train = torch.randperm(n_train, device=device)
        train_loss = 0.0
        n_batches = 0

        for start in range(0, n_train, args.batch_size):
            end = min(start + args.batch_size, n_train)
            idx = perm_train[start:end]

            optimizer.zero_grad()
            pred = model(X_train_t[idx])
            loss = loss_fn(pred, y_train_t[idx])
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            n_batches += 1

        train_loss /= n_batches

        # Validation: whole validation set in a single forward pass
        model.eval()
        with torch.no_grad():
            val_pred = model(X_val_t)
            val_loss = loss_fn(val_pred, y_val_t).item()

        marker = ''
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            patience_counter = 0
            # Clone so later optimizer steps cannot mutate the snapshot.
            best_state = {k: v.clone() for k, v in model.state_dict().items()}
            marker = ' *best*'
        else:
            patience_counter += 1

        print(f"Epoch {epoch:3d}: train_loss={train_loss:.4f} val_loss={val_loss:.4f}{marker}")

        if patience_counter >= args.patience:
            print(f"\nEarly stopping at epoch {epoch} (best epoch: {best_epoch})")
            break

    # Restore best weights
    if best_state is not None:
        model.load_state_dict(best_state)

    print(f"\nBest epoch: {best_epoch}, val_loss: {best_val_loss:.6f}")

    # Print accuracy
    model.eval()
    print_examples_and_accuracy(model, X_val_t, y_val, max_examples=8)

    # Save
    export_safetensors(model, args.save, total, best_epoch, best_val_loss)

    return 0
|
|
|
|
|
|
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == '__main__':
    raise SystemExit(main())
|