Batch Size
Over 9000!!!

Sei stufo di sentirti sopraffatto da quei 2048 samples di batch size dei paper di Facebook? Ti sei stancato di dover sempre scendere a compromessi tra risoluzione o batch size per allenare il tuo modello sulla tua sudata GPU da 8 GB di RAM? Se questo fosse il tuo caso sei nel posto giusto. È l’ora di dire basta 🙂

Con questo semplice trucco, che l’umanità si tramanda sin dagli antichi Egizi, potrai facilmente addestrare le tue reti con batch size delle dimensioni che vorrai. E tutto questo senza vedere quel fastidiosissimo errore di “out-of-memory”, che affligge tutti noi comuni mortai. 
 
Il trucco è molto semplice: basta accumulare i gradienti, fare le giuste divisioni e aggiornare i pesi solo dopo un certo numero n di iterazioni. Quindi, invece di aggiornare i parametri ad ogni step della nostra epoca, lo faremo solo dopo n*steps. In questo modo, se la mia GPU per uno specifico modello riesce ad avere un batch size  massimo di 4, aspettando 16 steps posso ottenere un batch size di 64🥳. Quindi, il batch size effettivo sulla GPU rimane di quattro, ma l’effetto dell’update è ugual e matematicamente identico a uno da 64.
 
Questo trucchetto da prestigiatore dei poveri può essere fatto facilmente con tutte le librerie di calcolo scientifico. Nel caso di steps con un numero uguale di sample, si riportano  di seguito i codici in TensorFlow e PyTorch, da inserire nel proprio training loop:
				
					# tensorfow
gradients = [tf.zeros_like(var) for var in model.trainable_variables]
 for i, (inputs, labels) in enumerate(train_ds):
 	with tf.GradientTape() as tape:
 		predictions = model(inputs)
 		loss = loss_function(predictions, labels)  
 	new_gradients = tape.gradient(loss, model.trainable_variables)
 	gradients = [(new_grad+grad)/ accumulation_steps  for new_grad, grad in zip(new_gradients, gradients)]
 		if (i+1) % accumulation_steps == 0:  
 			optimizer.apply_gradients(zip(gradients, model.trainable_variables))
 			gradients = [tf.zeros_like(var) for var in model.trainable_variables]
				
			
				
					# pytorch
model.zero_grad()                                  
for i, (inputs, labels) in enumerate(train_ds):
    predictions = model(inputs)    
    loss = loss_function(predictions, labels)              
    loss = loss / accumulation_steps              
    loss.backward()                                
    if (i+1) % accumulation_steps == 0:            
        optimizer.step()                            
        model.zero_grad()  
				
			

Per quanto riguarda addestramenti più particolari, in cui il numero di sample sono differenti per ogni step, i codici sono i seguenti:

				
					# tensorfow with different samples per step
gradients = [tf.zeros_like(var) for var in model.trainable_variables]
 for i, (inputs, labels) in enumerate(train_ds):
    with tf.GradientTape() as tape:
        predictions = model(inputs)
        loss, samples_per_step  = loss_function(predictions, labels)  # the loss should provide the number of samples per step     
    new_gradients = tape.gradient(loss, model.trainable_variables)
    gradients = [((new_grad+grad)/ tot_samples_batch) * samples_per_step  for new_grad, grad in zip(new_gradients, gradients)]
        if (i+1) % accumulation_steps == 0:  
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))
            gradients = [tf.zeros_like(var) for var in model.trainable_variables]
				
			
				
					# pytorch with different samples per step
model.zero_grad()                                  
for i, (inputs, labels) in enumerate(train_ds):
    predictions = model(inputs)    
    loss, samples_per_step = loss_function(predictions, labels)   # the loss should provide the number of samples per step          
    loss = (loss / tot_samples_batch) *  samples_per_step           
    loss.backward()                                
    if (i+1) % accumulation_steps == 0:            
        optimizer.step()                            
        model.zero_grad()  
				
			

Questo si rende necessario per la media aritmetica che fa la birichina. Infatti, se medio solo elementi che hanno tutti lo stesso peso (tipo tutti derivano da una media di due elementi), non incorro in nessun problema. Quindi, banalmente, se ho una lista [1,1,2,2] e medio i primi due elementi e gli ultimi due [(1+1)/2, (2+2)/2]=[1,2] e infine medio i valori rimanenti [(1+2)/2]=1,5, questo è uguale a mediare tutti gli elementi insieme dall’inizio (1+1+2+2)/4=1,5. Questo però non succede se medio elementi con diverso peso: se prima mediassi, (1,1,2)/3=1,33, e poi (1,33+2)/2=sicuramente non 1,5. Questo perchè sto semplicemente mediando cose diverse. 

A parte questi discorsi sui massimi sistemi, spero che questo semplice stratagemma possa esservi d’aiuto  e che possiate addestrare tutte le reti che vorrete con le dovute risoluzioni e giusti batch.