diff --git a/Autokem/train_torch.py b/Autokem/train_torch.py index 5eca787..e583539 100644 --- a/Autokem/train_torch.py +++ b/Autokem/train_torch.py @@ -398,7 +398,7 @@ def _shift_image(img, dx, dy): def _augment_one(img, label, rng): """One augmented copy: random 1px shift + 1% pixel dropout.""" dx = rng.integers(-1, 2) # -1, 0, or 1 - dy = 0 + dy = rng.integers(-1, 2) # -1, 0, or 1 aug = _shift_image(img, dx, dy) # mask = rng.random(aug.shape) > 0.01 # aug = aug * mask