revised autokem model

This commit is contained in:
minjaesong
2026-03-09 23:46:28 +09:00
parent 244371aa9d
commit 268610a8b3
10 changed files with 834 additions and 14 deletions

View File

@@ -188,7 +188,7 @@ def build_model():
class Keminet(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 7, padding=1)
self.conv1 = nn.Conv2d(1, 32, 5, padding=1)
self.conv2 = nn.Conv2d(32, 64, 7, padding=1)
self.fc1 = nn.Linear(64, 256)
# self.fc2 = nn.Linear(256, 48)
@@ -306,6 +306,7 @@ def load_safetensors(model, path):
BIT_NAMES = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K', 'Ytype', 'LowH']
SHAPE_CHARS = 'ABCDEFGHJK'
MIRROR_PAIRS = [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)] # A↔B, C↔D, E↔F, G↔H, J↔K
def format_tag(bits_12):
@@ -359,6 +360,180 @@ def print_examples_and_accuracy(model, X_val, y_val, max_examples=8):
print(f" Overall: {total_correct}/{n_val*12} ({100*total_correct/(n_val*12):.2f}%)")
# ---- Data augmentation ----
def _shape_key(label):
"""10-bit shape tuple from label (A through K)."""
return tuple(int(label[i]) for i in range(10))
def _mirror_shape(key):
"""Swap mirror pairs: A↔B, C↔D, E↔F, G↔H, J↔K."""
m = list(key)
for a, b in MIRROR_PAIRS:
m[a], m[b] = m[b], m[a]
return tuple(m)
def _mirror_label(label):
"""Mirror shape bits in label, keep ytype and lowheight."""
m = label.copy()
for a, b in MIRROR_PAIRS:
m[a], m[b] = m[b], m[a]
return m
def _shift_image(img, dx, dy):
"""Shift 2D image by (dx, dy), fill with 0."""
h, w = img.shape
shifted = np.zeros_like(img)
sx0, sx1 = max(0, -dx), min(w, w - dx)
sy0, sy1 = max(0, -dy), min(h, h - dy)
dx0, dx1 = max(0, dx), min(w, w + dx)
dy0, dy1 = max(0, dy), min(h, h + dy)
shifted[dy0:dy1, dx0:dx1] = img[sy0:sy1, sx0:sx1]
return shifted
def _augment_one(img, label, rng):
"""One augmented copy: random 1px shift + 1% pixel dropout."""
dx = rng.integers(-1, 2) # -1, 0, or 1
dy = 0
aug = _shift_image(img, dx, dy)
# mask = rng.random(aug.shape) > 0.01
# aug = aug * mask
return aug, label.copy()
def _do_mirror_augmentation(X, y, rng):
"""For each mirror pair (S, mirror(S)), fill deficit from the common side."""
shape_counts = {}
shape_indices = {}
for i in range(len(y)):
key = _shape_key(y[i])
shape_counts[key] = shape_counts.get(key, 0) + 1
shape_indices.setdefault(key, []).append(i)
new_X, new_y = [], []
done = set() # avoid processing both directions
for key, count in shape_counts.items():
if key in done:
continue
mkey = _mirror_shape(key)
done.add(key)
done.add(mkey)
if mkey == key:
continue # symmetric shape
mcount = shape_counts.get(mkey, 0)
if count == mcount:
continue
# Mirror from the larger side to fill the smaller side
if count > mcount:
src_key, deficit = key, count - mcount
else:
src_key, deficit = mkey, mcount - count
indices = shape_indices.get(src_key, [])
if not indices:
continue
chosen = rng.choice(indices, size=deficit, replace=True)
for idx in chosen:
new_X.append(np.fliplr(X[idx]).copy())
new_y.append(_mirror_label(y[idx]))
if new_X:
X = np.concatenate([X, np.array(new_X)])
y = np.concatenate([y, np.array(new_y)])
return X, y
def _compute_rarity_weights(y):
"""Per-sample weight: sum of inverse bit frequencies for all 12 bits.
Samples with rare bit values (e.g. J=1 at 13%, C=0 at 8%) get higher weight.
"""
bit_freq = y.mean(axis=0) # [12], P(bit=1)
weights = np.zeros(len(y))
for i in range(len(y)):
w = 0.0
for b in range(12):
p = bit_freq[b] if y[i, b] > 0.5 else (1.0 - bit_freq[b])
w += 1.0 / max(p, 0.01)
weights[i] = w
return weights
def _do_rarity_augmentation(X, y, rng, target_new):
"""Create target_new augmented samples, drawn proportionally to rarity weight."""
if target_new <= 0:
return X, y
weights = _compute_rarity_weights(y)
weights /= weights.sum()
chosen = rng.choice(len(X), size=target_new, replace=True, p=weights)
new_X, new_y = [], []
for idx in chosen:
aug_img, aug_label = _augment_one(X[idx], y[idx], rng)
new_X.append(aug_img)
new_y.append(aug_label)
X = np.concatenate([X, np.array(new_X)])
y = np.concatenate([y, np.array(new_y)])
return X, y
def _print_bit_freq(y, label):
"""Print per-bit frequencies for diagnostics."""
freq = y.mean(axis=0)
names = BIT_NAMES
parts = [f'{names[b]}:{freq[b]*100:.0f}%' for b in range(12)]
print(f" {label}: {' '.join(parts)}")
def augment_training_data(X_train, y_train, rng, aug_factor=3.0):
"""
Three-phase data augmentation:
1. Mirror augmentation — fill deficit between mirror-paired shapes
2. Rarity-weighted — samples with rare bit values get more copies (shift+dropout)
3. Y-type boost — repeat phases 1-2 scoped to Y-type samples only
"""
n0 = len(X_train)
_print_bit_freq(y_train, 'Before')
# Phase 1: Mirror augmentation
X_train, y_train = _do_mirror_augmentation(X_train, y_train, rng)
n1 = len(X_train)
# Phase 2: Rarity-weighted augmentation — target aug_factor × original size
target_new = int(n0 * aug_factor) - n1
X_train, y_train = _do_rarity_augmentation(X_train, y_train, rng, target_new)
n2 = len(X_train)
# Phase 3: Y-type boost — same pipeline for Y-type subset only
ytype_mask = y_train[:, 10] > 0.5
n_ytype_existing = int(ytype_mask.sum())
if n_ytype_existing > 0:
X_yt = X_train[ytype_mask]
y_yt = y_train[ytype_mask]
X_yt, y_yt = _do_mirror_augmentation(X_yt, y_yt, rng)
# Double the Y-type subset via rarity augmentation
yt_new = n_ytype_existing
X_yt, y_yt = _do_rarity_augmentation(X_yt, y_yt, rng, yt_new)
if len(X_yt) > n_ytype_existing:
X_train = np.concatenate([X_train, X_yt[n_ytype_existing:]])
y_train = np.concatenate([y_train, y_yt[n_ytype_existing:]])
n3 = len(X_train)
_print_bit_freq(y_train, 'After ')
print(f"Data augmentation: {n0}{n3} samples ({n3/n0:.1f}×)")
print(f" Mirror: +{n1 - n0}, Rarity: +{n2 - n1}, Y-type boost: +{n3 - n2}")
return X_train, y_train
# ---- Main ----
def main():
@@ -376,6 +551,10 @@ def main():
help='Early stopping patience (default: 10)')
parser.add_argument('--val-split', type=float, default=0.2,
help='Validation split (default: 0.2)')
parser.add_argument('--no-augment', action='store_true',
help='Disable data augmentation')
parser.add_argument('--aug-factor', type=float, default=3.0,
help='Augmentation target multiplier (default: 3.0)')
args = parser.parse_args()
import torch
@@ -404,10 +583,17 @@ def main():
perm = rng.permutation(total)
X, y = X[perm], y[perm]
n_train = int(total * (1 - args.val_split))
n_val = int(total * args.val_split)
n_train = total - n_val
X_train, X_val = X[:n_train], X[n_train:]
y_train, y_val = y[:n_train], y[n_train:]
print(f"Train: {n_train}, Validation: {total - n_train}\n")
print(f"Train: {n_train}, Validation: {n_val}")
# Data augmentation (training set only)
if not args.no_augment:
X_train, y_train = augment_training_data(X_train, y_train, rng, args.aug_factor)
n_train = len(X_train)
print()
# Convert to tensors — PyTorch conv expects [N, C, H, W]
X_train_t = torch.from_numpy(X_train[:, np.newaxis, :, :]).to(device) # [N,1,20,15]