1
0

batchnorm_layer.c 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
#include "convolutional_layer.h"
#include "batchnorm_layer.h"
#include "blas.h"
#include <stdio.h>
#include <stdlib.h>
  5. layer make_batchnorm_layer(int batch, int w, int h, int c)
  6. {
  7. fprintf(stderr, "Batch Normalization Layer: %d x %d x %d image\n", w,h,c);
  8. layer l = {0};
  9. l.type = BATCHNORM;
  10. l.batch = batch;
  11. l.h = l.out_h = h;
  12. l.w = l.out_w = w;
  13. l.c = l.out_c = c;
  14. l.output = calloc(h * w * c * batch, sizeof(float));
  15. l.delta = calloc(h * w * c * batch, sizeof(float));
  16. l.inputs = w*h*c;
  17. l.outputs = l.inputs;
  18. l.scales = calloc(c, sizeof(float));
  19. l.scale_updates = calloc(c, sizeof(float));
  20. l.biases = calloc(c, sizeof(float));
  21. l.bias_updates = calloc(c, sizeof(float));
  22. int i;
  23. for(i = 0; i < c; ++i){
  24. l.scales[i] = 1;
  25. }
  26. l.mean = calloc(c, sizeof(float));
  27. l.variance = calloc(c, sizeof(float));
  28. l.rolling_mean = calloc(c, sizeof(float));
  29. l.rolling_variance = calloc(c, sizeof(float));
  30. l.forward = forward_batchnorm_layer;
  31. l.backward = backward_batchnorm_layer;
  32. #ifdef GPU
  33. l.forward_gpu = forward_batchnorm_layer_gpu;
  34. l.backward_gpu = backward_batchnorm_layer_gpu;
  35. l.output_gpu = cuda_make_array(l.output, h * w * c * batch);
  36. l.delta_gpu = cuda_make_array(l.delta, h * w * c * batch);
  37. l.biases_gpu = cuda_make_array(l.biases, c);
  38. l.bias_updates_gpu = cuda_make_array(l.bias_updates, c);
  39. l.scales_gpu = cuda_make_array(l.scales, c);
  40. l.scale_updates_gpu = cuda_make_array(l.scale_updates, c);
  41. l.mean_gpu = cuda_make_array(l.mean, c);
  42. l.variance_gpu = cuda_make_array(l.variance, c);
  43. l.rolling_mean_gpu = cuda_make_array(l.mean, c);
  44. l.rolling_variance_gpu = cuda_make_array(l.variance, c);
  45. l.mean_delta_gpu = cuda_make_array(l.mean, c);
  46. l.variance_delta_gpu = cuda_make_array(l.variance, c);
  47. l.x_gpu = cuda_make_array(l.output, l.batch*l.outputs);
  48. l.x_norm_gpu = cuda_make_array(l.output, l.batch*l.outputs);
  49. #ifdef CUDNN
  50. cudnnCreateTensorDescriptor(&l.normTensorDesc);
  51. cudnnCreateTensorDescriptor(&l.dstTensorDesc);
  52. cudnnSetTensor4dDescriptor(l.dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w);
  53. cudnnSetTensor4dDescriptor(l.normTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, l.out_c, 1, 1);
  54. #endif
  55. #endif
  56. return l;
  57. }
  58. void backward_scale_cpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
  59. {
  60. int i,b,f;
  61. for(f = 0; f < n; ++f){
  62. float sum = 0;
  63. for(b = 0; b < batch; ++b){
  64. for(i = 0; i < size; ++i){
  65. int index = i + size*(f + n*b);
  66. sum += delta[index] * x_norm[index];
  67. }
  68. }
  69. scale_updates[f] += sum;
  70. }
  71. }
  72. void mean_delta_cpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
  73. {
  74. int i,j,k;
  75. for(i = 0; i < filters; ++i){
  76. mean_delta[i] = 0;
  77. for (j = 0; j < batch; ++j) {
  78. for (k = 0; k < spatial; ++k) {
  79. int index = j*filters*spatial + i*spatial + k;
  80. mean_delta[i] += delta[index];
  81. }
  82. }
  83. mean_delta[i] *= (-1./sqrt(variance[i] + .00001f));
  84. }
  85. }
  86. void variance_delta_cpu(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta)
  87. {
  88. int i,j,k;
  89. for(i = 0; i < filters; ++i){
  90. variance_delta[i] = 0;
  91. for(j = 0; j < batch; ++j){
  92. for(k = 0; k < spatial; ++k){
  93. int index = j*filters*spatial + i*spatial + k;
  94. variance_delta[i] += delta[index]*(x[index] - mean[i]);
  95. }
  96. }
  97. variance_delta[i] *= -.5 * pow(variance[i] + .00001f, (float)(-3./2.));
  98. }
  99. }
  100. void normalize_delta_cpu(float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta)
  101. {
  102. int f, j, k;
  103. for(j = 0; j < batch; ++j){
  104. for(f = 0; f < filters; ++f){
  105. for(k = 0; k < spatial; ++k){
  106. int index = j*filters*spatial + f*spatial + k;
  107. delta[index] = delta[index] * 1./(sqrt(variance[f] + .00001f)) + variance_delta[f] * 2. * (x[index] - mean[f]) / (spatial * batch) + mean_delta[f]/(spatial*batch);
  108. }
  109. }
  110. }
  111. }
  112. void resize_batchnorm_layer(layer *layer, int w, int h)
  113. {
  114. fprintf(stderr, "Not implemented\n");
  115. }
  116. void forward_batchnorm_layer(layer l, network net)
  117. {
  118. if(l.type == BATCHNORM) copy_cpu(l.outputs*l.batch, net.input, 1, l.output, 1);
  119. copy_cpu(l.outputs*l.batch, l.output, 1, l.x, 1);
  120. if(net.train){
  121. mean_cpu(l.output, l.batch, l.out_c, l.out_h*l.out_w, l.mean);
  122. variance_cpu(l.output, l.mean, l.batch, l.out_c, l.out_h*l.out_w, l.variance);
  123. scal_cpu(l.out_c, .99, l.rolling_mean, 1);
  124. axpy_cpu(l.out_c, .01, l.mean, 1, l.rolling_mean, 1);
  125. scal_cpu(l.out_c, .99, l.rolling_variance, 1);
  126. axpy_cpu(l.out_c, .01, l.variance, 1, l.rolling_variance, 1);
  127. normalize_cpu(l.output, l.mean, l.variance, l.batch, l.out_c, l.out_h*l.out_w);
  128. copy_cpu(l.outputs*l.batch, l.output, 1, l.x_norm, 1);
  129. } else {
  130. normalize_cpu(l.output, l.rolling_mean, l.rolling_variance, l.batch, l.out_c, l.out_h*l.out_w);
  131. }
  132. scale_bias(l.output, l.scales, l.batch, l.out_c, l.out_h*l.out_w);
  133. add_bias(l.output, l.biases, l.batch, l.out_c, l.out_h*l.out_w);
  134. }
  135. void backward_batchnorm_layer(layer l, network net)
  136. {
  137. if(!net.train){
  138. l.mean = l.rolling_mean;
  139. l.variance = l.rolling_variance;
  140. }
  141. backward_bias(l.bias_updates, l.delta, l.batch, l.out_c, l.out_w*l.out_h);
  142. backward_scale_cpu(l.x_norm, l.delta, l.batch, l.out_c, l.out_w*l.out_h, l.scale_updates);
  143. scale_bias(l.delta, l.scales, l.batch, l.out_c, l.out_h*l.out_w);
  144. mean_delta_cpu(l.delta, l.variance, l.batch, l.out_c, l.out_w*l.out_h, l.mean_delta);
  145. variance_delta_cpu(l.x, l.delta, l.mean, l.variance, l.batch, l.out_c, l.out_w*l.out_h, l.variance_delta);
  146. normalize_delta_cpu(l.x, l.mean, l.variance, l.mean_delta, l.variance_delta, l.batch, l.out_c, l.out_w*l.out_h, l.delta);
  147. if(l.type == BATCHNORM) copy_cpu(l.outputs*l.batch, l.delta, 1, net.delta, 1);
  148. }
  149. #ifdef GPU
  150. void pull_batchnorm_layer(layer l)
  151. {
  152. cuda_pull_array(l.scales_gpu, l.scales, l.c);
  153. cuda_pull_array(l.rolling_mean_gpu, l.rolling_mean, l.c);
  154. cuda_pull_array(l.rolling_variance_gpu, l.rolling_variance, l.c);
  155. }
  156. void push_batchnorm_layer(layer l)
  157. {
  158. cuda_push_array(l.scales_gpu, l.scales, l.c);
  159. cuda_push_array(l.rolling_mean_gpu, l.rolling_mean, l.c);
  160. cuda_push_array(l.rolling_variance_gpu, l.rolling_variance, l.c);
  161. }
  162. void forward_batchnorm_layer_gpu(layer l, network net)
  163. {
  164. if(l.type == BATCHNORM) copy_gpu(l.outputs*l.batch, net.input_gpu, 1, l.output_gpu, 1);
  165. copy_gpu(l.outputs*l.batch, l.output_gpu, 1, l.x_gpu, 1);
  166. if (net.train) {
  167. #ifdef CUDNN
  168. float one = 1;
  169. float zero = 0;
  170. cudnnBatchNormalizationForwardTraining(cudnn_handle(),
  171. CUDNN_BATCHNORM_SPATIAL,
  172. &one,
  173. &zero,
  174. l.dstTensorDesc,
  175. l.x_gpu,
  176. l.dstTensorDesc,
  177. l.output_gpu,
  178. l.normTensorDesc,
  179. l.scales_gpu,
  180. l.biases_gpu,
  181. .01,
  182. l.rolling_mean_gpu,
  183. l.rolling_variance_gpu,
  184. .00001,
  185. l.mean_gpu,
  186. l.variance_gpu);
  187. #else
  188. fast_mean_gpu(l.output_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.mean_gpu);
  189. fast_variance_gpu(l.output_gpu, l.mean_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.variance_gpu);
  190. scal_gpu(l.out_c, .99, l.rolling_mean_gpu, 1);
  191. axpy_gpu(l.out_c, .01, l.mean_gpu, 1, l.rolling_mean_gpu, 1);
  192. scal_gpu(l.out_c, .99, l.rolling_variance_gpu, 1);
  193. axpy_gpu(l.out_c, .01, l.variance_gpu, 1, l.rolling_variance_gpu, 1);
  194. copy_gpu(l.outputs*l.batch, l.output_gpu, 1, l.x_gpu, 1);
  195. normalize_gpu(l.output_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.out_c, l.out_h*l.out_w);
  196. copy_gpu(l.outputs*l.batch, l.output_gpu, 1, l.x_norm_gpu, 1);
  197. scale_bias_gpu(l.output_gpu, l.scales_gpu, l.batch, l.out_c, l.out_h*l.out_w);
  198. add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.out_c, l.out_w*l.out_h);
  199. #endif
  200. } else {
  201. normalize_gpu(l.output_gpu, l.rolling_mean_gpu, l.rolling_variance_gpu, l.batch, l.out_c, l.out_h*l.out_w);
  202. scale_bias_gpu(l.output_gpu, l.scales_gpu, l.batch, l.out_c, l.out_h*l.out_w);
  203. add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.out_c, l.out_w*l.out_h);
  204. }
  205. }
  206. void backward_batchnorm_layer_gpu(layer l, network net)
  207. {
  208. if(!net.train){
  209. l.mean_gpu = l.rolling_mean_gpu;
  210. l.variance_gpu = l.rolling_variance_gpu;
  211. }
  212. #ifdef CUDNN
  213. float one = 1;
  214. float zero = 0;
  215. cudnnBatchNormalizationBackward(cudnn_handle(),
  216. CUDNN_BATCHNORM_SPATIAL,
  217. &one,
  218. &zero,
  219. &one,
  220. &one,
  221. l.dstTensorDesc,
  222. l.x_gpu,
  223. l.dstTensorDesc,
  224. l.delta_gpu,
  225. l.dstTensorDesc,
  226. l.x_norm_gpu,
  227. l.normTensorDesc,
  228. l.scales_gpu,
  229. l.scale_updates_gpu,
  230. l.bias_updates_gpu,
  231. .00001,
  232. l.mean_gpu,
  233. l.variance_gpu);
  234. copy_gpu(l.outputs*l.batch, l.x_norm_gpu, 1, l.delta_gpu, 1);
  235. #else
  236. backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.out_c, l.out_w*l.out_h);
  237. backward_scale_gpu(l.x_norm_gpu, l.delta_gpu, l.batch, l.out_c, l.out_w*l.out_h, l.scale_updates_gpu);
  238. scale_bias_gpu(l.delta_gpu, l.scales_gpu, l.batch, l.out_c, l.out_h*l.out_w);
  239. fast_mean_delta_gpu(l.delta_gpu, l.variance_gpu, l.batch, l.out_c, l.out_w*l.out_h, l.mean_delta_gpu);
  240. fast_variance_delta_gpu(l.x_gpu, l.delta_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.out_c, l.out_w*l.out_h, l.variance_delta_gpu);
  241. normalize_delta_gpu(l.x_gpu, l.mean_gpu, l.variance_gpu, l.mean_delta_gpu, l.variance_delta_gpu, l.batch, l.out_c, l.out_w*l.out_h, l.delta_gpu);
  242. #endif
  243. if(l.type == BATCHNORM) copy_gpu(l.outputs*l.batch, l.delta_gpu, 1, net.delta_gpu, 1);
  244. }
  245. #endif