mirror of
https://github.com/curioustorvald/Terrarum-sans-bitmap.git
synced 2026-06-11 08:24:04 +09:00
revised autokem model
This commit is contained in:
@@ -188,7 +188,7 @@ def build_model():
|
||||
class Keminet(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(1, 32, 7, padding=1)
|
||||
self.conv1 = nn.Conv2d(1, 32, 5, padding=1)
|
||||
self.conv2 = nn.Conv2d(32, 64, 7, padding=1)
|
||||
self.fc1 = nn.Linear(64, 256)
|
||||
# self.fc2 = nn.Linear(256, 48)
|
||||
@@ -306,6 +306,7 @@ def load_safetensors(model, path):
|
||||
|
||||
# Letters naming the ten shape bits (label positions 0-9); 'I' is not used.
SHAPE_CHARS = 'ABCDEFGHJK'
# Names of all twelve label bits: the ten shape bits, then Ytype and LowH.
BIT_NAMES = [*SHAPE_CHARS, 'Ytype', 'LowH']
# Index pairs of shape bits that trade places under mirroring:
# A↔B, C↔D, E↔F, G↔H, J↔K
MIRROR_PAIRS = [(left, left + 1) for left in range(0, len(SHAPE_CHARS), 2)]
|
||||
|
||||
|
||||
def format_tag(bits_12):
|
||||
@@ -359,6 +360,180 @@ def print_examples_and_accuracy(model, X_val, y_val, max_examples=8):
|
||||
print(f" Overall: {total_correct}/{n_val*12} ({100*total_correct/(n_val*12):.2f}%)")
|
||||
|
||||
|
||||
# ---- Data augmentation ----
|
||||
|
||||
def _shape_key(label):
    """Return the ten shape bits (A through K) of a label as a hashable key.

    Args:
        label: Indexable sequence whose first ten entries are the shape bits.

    Returns:
        A 10-tuple of ints, one per shape bit, in label order.
    """
    bits = []
    for position in range(10):
        bits.append(int(label[position]))
    return tuple(bits)
|
||||
|
||||
|
||||
def _mirror_shape(key):
    """Return the shape key of the mirrored glyph.

    Every index pair listed in MIRROR_PAIRS (A↔B, C↔D, E↔F, G↔H, J↔K)
    trades places; the input is left untouched.

    Args:
        key: Shape tuple as produced by ``_shape_key``.

    Returns:
        A new tuple with each mirror pair swapped.
    """
    # Build the permutation once, then gather the bits through it.
    source = list(range(len(key)))
    for left, right in MIRROR_PAIRS:
        source[left], source[right] = source[right], source[left]
    return tuple(key[position] for position in source)
|
||||
|
||||
|
||||
def _mirror_label(label):
    """Return a copy of a label with its mirror-paired shape bits swapped.

    Only the positions listed in MIRROR_PAIRS are exchanged, so the
    remaining entries (Ytype and LowH) are carried over unchanged.

    Args:
        label: Label vector supporting ``.copy()`` and item assignment.

    Returns:
        The mirrored copy; ``label`` itself is not modified.
    """
    mirrored = label.copy()
    for first, second in MIRROR_PAIRS:
        held = mirrored[first]
        mirrored[first] = mirrored[second]
        mirrored[second] = held
    return mirrored
|
||||
|
||||
|
||||
def _shift_image(img, dx, dy):
    """Translate a 2D image by whole pixels, padding vacated cells with 0.

    Args:
        img: 2D array of shape (height, width).
        dx: Column offset; positive values move content toward higher
            column indices.
        dy: Row offset; positive values move content toward higher row
            indices.

    Returns:
        A new array with the same shape and dtype as ``img``. Content moved
        past an edge is discarded. A shift at least as large as the image
        in either direction yields an all-zero image.
    """
    h, w = img.shape
    shifted = np.zeros_like(img)

    # Nothing of the source remains in frame. Without this guard the source
    # and destination slices below would have mismatched extents and the
    # assignment would raise a broadcast error.
    if abs(dx) >= w or abs(dy) >= h:
        return shifted

    src_rows = slice(max(0, -dy), h - max(0, dy))
    src_cols = slice(max(0, -dx), w - max(0, dx))
    dst_rows = slice(max(0, dy), h - max(0, -dy))
    dst_cols = slice(max(0, dx), w - max(0, -dx))
    shifted[dst_rows, dst_cols] = img[src_rows, src_cols]
    return shifted
|
||||
|
||||
|
||||
def _augment_one(img, label, rng):
    """Produce one augmented copy of a sample via a random horizontal shift.

    The image is moved by -1, 0 or +1 pixel along x; there is no vertical
    shift. Pixel dropout is currently disabled, so the shift is the only
    perturbation applied.

    Args:
        img: 2D image array.
        label: Label vector supporting ``.copy()``.
        rng: Random source exposing ``integers(low, high)``
            (e.g. a ``numpy.random.Generator``).

    Returns:
        ``(shifted_image, label_copy)``.
    """
    horizontal = rng.integers(-1, 2)  # -1, 0, or 1
    shifted = _shift_image(img, horizontal, 0)
    # Disabled 1% pixel dropout, kept for reference:
    # mask = rng.random(shifted.shape) > 0.01
    # shifted = shifted * mask
    return shifted, label.copy()
|
||||
|
||||
|
||||
def _do_mirror_augmentation(X, y, rng):
    """Balance each shape against its mirror image by adding flipped samples.

    Samples are grouped by shape key. For every pair (shape, mirrored shape)
    with unequal counts, samples are drawn with replacement from the more
    common side, flipped left-right, relabelled with mirrored bits, and
    appended until both sides have the same count. Shapes that are their own
    mirror are skipped.

    Args:
        X: Array of images, indexed by sample.
        y: Array of labels, one row per sample.
        rng: Random source exposing ``choice`` (e.g. ``numpy.random.Generator``).

    Returns:
        ``(X, y)`` with the synthesized samples appended, or the inputs
        unchanged when nothing needed balancing.
    """
    # Sample indices grouped by shape, in order of first appearance.
    members = {}
    for row, label in enumerate(y):
        members.setdefault(_shape_key(label), []).append(row)

    extra_X, extra_y = [], []
    visited = set()  # a pair is handled once, from whichever side comes first

    for shape, rows in members.items():
        if shape in visited:
            continue
        twin = _mirror_shape(shape)
        visited.update((shape, twin))
        if twin == shape:
            continue  # symmetric shape

        n_shape = len(rows)
        n_twin = len(members.get(twin, ()))
        if n_shape == n_twin:
            continue

        # Flip samples from the larger side to top up the smaller side.
        if n_shape > n_twin:
            donors, shortfall = rows, n_shape - n_twin
        else:
            donors, shortfall = members[twin], n_twin - n_shape

        for pick in rng.choice(donors, size=shortfall, replace=True):
            extra_X.append(np.fliplr(X[pick]).copy())
            extra_y.append(_mirror_label(y[pick]))

    if extra_X:
        X = np.concatenate([X, np.array(extra_X)])
        y = np.concatenate([y, np.array(extra_y)])
    return X, y
|
||||
|
||||
|
||||
def _compute_rarity_weights(y):
    """Per-sample weight: sum of inverse frequencies of the sample's bit values.

    For each bit column, the frequency of the value a sample actually holds
    (1 or 0, thresholded at 0.5) is looked up, floored at 0.01 so a single
    very rare value cannot contribute more than 100, and inverted. The
    per-bit terms are summed across all columns of ``y``. Samples with rare
    bit values (e.g. J=1 at 13%, C=0 at 8%) therefore get higher weight.

    Args:
        y: 2D array of labels, shape (n_samples, n_bits).

    Returns:
        1D float64 array of length n_samples. Empty when ``y`` has no rows.
    """
    labels = np.asarray(y, dtype=np.float64)
    if labels.shape[0] == 0:
        # Avoid the mean-of-empty warning; there is nothing to weight.
        return np.zeros(0)

    bit_freq = labels.mean(axis=0)  # P(bit=1) per column
    # Frequency of the value each sample actually carries in each column.
    observed_freq = np.where(labels > 0.5, bit_freq, 1.0 - bit_freq)
    return (1.0 / np.maximum(observed_freq, 0.01)).sum(axis=1)
|
||||
|
||||
|
||||
def _do_rarity_augmentation(X, y, rng, target_new):
    """Append ``target_new`` augmented samples, favouring rare bit values.

    Source samples are drawn with replacement, with probability proportional
    to their rarity weight, and each draw is passed through ``_augment_one``.

    Args:
        X: Array of images, indexed by sample.
        y: Array of labels, one row per sample.
        rng: Random source exposing ``choice`` (e.g. ``numpy.random.Generator``).
        target_new: Number of samples to add; values <= 0 add nothing.

    Returns:
        ``(X, y)`` with the new samples appended, or the inputs unchanged
        when ``target_new`` is not positive.
    """
    if target_new <= 0:
        return X, y

    # Normalise the rarity weights into a sampling distribution.
    rarity = _compute_rarity_weights(y)
    rarity /= rarity.sum()

    picks = rng.choice(len(X), size=target_new, replace=True, p=rarity)
    augmented = [_augment_one(X[pick], y[pick], rng) for pick in picks]

    images = np.array([image for image, _ in augmented])
    labels = np.array([tag for _, tag in augmented])
    return np.concatenate([X, images]), np.concatenate([y, labels])
|
||||
|
||||
|
||||
def _print_bit_freq(y, label):
    """Print the share of samples with each of the 12 label bits set.

    Args:
        y: 2D array of labels, one row per sample.
        label: Text printed in front of the frequencies to identify the line.
    """
    per_bit = y.mean(axis=0)
    pieces = []
    for bit in range(12):
        pieces.append(f'{BIT_NAMES[bit]}:{per_bit[bit]*100:.0f}%')
    print(f" {label}: {' '.join(pieces)}")
|
||||
|
||||
|
||||
def augment_training_data(X_train, y_train, rng, aug_factor=3.0):
    """Grow the training set in three phases and report what was added.

    1. Mirror augmentation: fill the deficit between mirror-paired shapes.
    2. Rarity-weighted augmentation: top the set up to ``aug_factor`` times
       the original size, drawing more often from samples with rare bits.
    3. Y-type boost: run phases 1-2 again on the Y-type subset only and
       append just the newly generated samples.

    Per-bit frequencies are printed before and after, followed by a summary
    of how many samples each phase contributed.

    Args:
        X_train: Array of training images, indexed by sample.
        y_train: Array of training labels; column 10 is the Y-type bit.
        rng: Random source shared by all phases.
        aug_factor: Target size after phase 2, as a multiple of the
            original sample count.

    Returns:
        ``(X_train, y_train)`` including all augmented samples.
    """
    original_count = len(X_train)
    _print_bit_freq(y_train, 'Before')

    # Phase 1: mirror augmentation.
    X_train, y_train = _do_mirror_augmentation(X_train, y_train, rng)
    after_mirror = len(X_train)

    # Phase 2: rarity-weighted augmentation up to aug_factor × original size.
    X_train, y_train = _do_rarity_augmentation(
        X_train, y_train, rng, int(original_count * aug_factor) - after_mirror
    )
    after_rarity = len(X_train)

    # Phase 3: same pipeline restricted to the Y-type subset.
    is_ytype = y_train[:, 10] > 0.5
    ytype_count = int(is_ytype.sum())
    if ytype_count > 0:
        X_boost, y_boost = _do_mirror_augmentation(
            X_train[is_ytype], y_train[is_ytype], rng
        )
        # Add as many rarity-weighted copies as there were Y-type samples.
        X_boost, y_boost = _do_rarity_augmentation(X_boost, y_boost, rng, ytype_count)
        if len(X_boost) > ytype_count:
            # The first ytype_count rows are already in the training set.
            X_train = np.concatenate([X_train, X_boost[ytype_count:]])
            y_train = np.concatenate([y_train, y_boost[ytype_count:]])

    final_count = len(X_train)

    _print_bit_freq(y_train, 'After ')
    print(f"Data augmentation: {original_count} → {final_count} samples ({final_count/original_count:.1f}×)")
    print(f" Mirror: +{after_mirror - original_count}, Rarity: +{after_rarity - after_mirror}, Y-type boost: +{final_count - after_rarity}")

    return X_train, y_train
|
||||
|
||||
|
||||
# ---- Main ----
|
||||
|
||||
def main():
|
||||
@@ -376,6 +551,10 @@ def main():
|
||||
help='Early stopping patience (default: 10)')
|
||||
parser.add_argument('--val-split', type=float, default=0.2,
|
||||
help='Validation split (default: 0.2)')
|
||||
parser.add_argument('--no-augment', action='store_true',
|
||||
help='Disable data augmentation')
|
||||
parser.add_argument('--aug-factor', type=float, default=3.0,
|
||||
help='Augmentation target multiplier (default: 3.0)')
|
||||
args = parser.parse_args()
|
||||
|
||||
import torch
|
||||
@@ -404,10 +583,17 @@ def main():
|
||||
perm = rng.permutation(total)
|
||||
X, y = X[perm], y[perm]
|
||||
|
||||
n_train = int(total * (1 - args.val_split))
|
||||
n_val = int(total * args.val_split)
|
||||
n_train = total - n_val
|
||||
X_train, X_val = X[:n_train], X[n_train:]
|
||||
y_train, y_val = y[:n_train], y[n_train:]
|
||||
print(f"Train: {n_train}, Validation: {total - n_train}\n")
|
||||
print(f"Train: {n_train}, Validation: {n_val}")
|
||||
|
||||
# Data augmentation (training set only)
|
||||
if not args.no_augment:
|
||||
X_train, y_train = augment_training_data(X_train, y_train, rng, args.aug_factor)
|
||||
n_train = len(X_train)
|
||||
print()
|
||||
|
||||
# Convert to tensors — PyTorch conv expects [N, C, H, W]
|
||||
X_train_t = torch.from_numpy(X_train[:, np.newaxis, :, :]).to(device) # [N,1,20,15]
|
||||
|
||||
Reference in New Issue
Block a user