Batch Size
Over 9000!!!
Sei stufo di sentirti sopraffatto da quei 2048 samples di batch size dei paper di Facebook? Ti sei stancato di dover sempre scendere a compromessi tra risoluzione o batch size per allenare il tuo modello sulla tua sudata GPU da 8 GB di RAM? Se questo fosse il tuo caso sei nel posto giusto. È l’ora di dire basta 🙂
# tensorfow
gradients = [tf.zeros_like(var) for var in model.trainable_variables]
for i, (inputs, labels) in enumerate(train_ds):
with tf.GradientTape() as tape:
predictions = model(inputs)
loss = loss_function(predictions, labels)
new_gradients = tape.gradient(loss, model.trainable_variables)
gradients = [(new_grad+grad)/ accumulation_steps for new_grad, grad in zip(new_gradients, gradients)]
if (i+1) % accumulation_steps == 0:
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
gradients = [tf.zeros_like(var) for var in model.trainable_variables]
# pytorch
model.zero_grad()
for i, (inputs, labels) in enumerate(train_ds):
predictions = model(inputs)
loss = loss_function(predictions, labels)
loss = loss / accumulation_steps
loss.backward()
if (i+1) % accumulation_steps == 0:
optimizer.step()
model.zero_grad()
Per quanto riguarda addestramenti più particolari, in cui il numero di sample sono differenti per ogni step, i codici sono i seguenti:
# tensorfow with different samples per step
gradients = [tf.zeros_like(var) for var in model.trainable_variables]
for i, (inputs, labels) in enumerate(train_ds):
with tf.GradientTape() as tape:
predictions = model(inputs)
loss, samples_per_step = loss_function(predictions, labels) # the loss should provide the number of samples per step
new_gradients = tape.gradient(loss, model.trainable_variables)
gradients = [((new_grad+grad)/ tot_samples_batch) * samples_per_step for new_grad, grad in zip(new_gradients, gradients)]
if (i+1) % accumulation_steps == 0:
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
gradients = [tf.zeros_like(var) for var in model.trainable_variables]
# pytorch with different samples per step
model.zero_grad()
for i, (inputs, labels) in enumerate(train_ds):
predictions = model(inputs)
loss, samples_per_step = loss_function(predictions, labels) # the loss should provide the number of samples per step
loss = (loss / tot_samples_batch) * samples_per_step
loss.backward()
if (i+1) % accumulation_steps == 0:
optimizer.step()
model.zero_grad()
Questo si rende necessario per la media aritmetica che fa la birichina. Infatti, se medio solo elementi che hanno tutti lo stesso peso (tipo tutti derivano da una media di due elementi), non incorro in nessun problema. Quindi, banalmente, se ho una lista [1,1,2,2] e medio i primi due elementi e gli ultimi due [(1+1)/2, (2+2)/2]=[1,2] e infine medio i valori rimanenti [(1+2)/2]=1,5, questo è uguale a mediare tutti gli elementi insieme dall’inizio (1+1+2+2)/4=1,5. Questo però non succede se medio elementi con diverso peso: se prima mediassi, (1,1,2)/3=1,33, e poi (1,33+2)/2=sicuramente non 1,5. Questo perchè sto semplicemente mediando cose diverse.
A parte questi discorsi sui massimi sistemi, spero che questo semplice stratagemma possa esservi d’aiuto e che possiate addestrare tutte le reti che vorrete con le dovute risoluzioni e giusti batch.