Detailed MLX benchmark

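Detailed runtime benchmark of mlx operations. Runtimes are in milliseconds (lower is better); speedup columns are percentages, where a positive value means the first backend in the column name is faster than the second.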
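The harness that produced these numbers is not included here. The sketch below shows one plausible way to get a per-operation runtime in milliseconds. It assumes warm-up runs followed by the median of repeated runs; the real benchmark's warm-up count and aggregation are not stated. MLX evaluates lazily and the MPS backend runs asynchronously, so the timed callable must force completion itself, for example with `mx.eval(out)` for MLX or `torch.mps.synchronize()` for PyTorch on MPS. The `time_ms` helper is a hypothetical name.

```python
import statistics
import time
from typing import Callable


def time_ms(fn: Callable[[], object], warmup: int = 5, repeats: int = 20) -> float:
    """Median wall-clock runtime of fn() in milliseconds.

    fn must block until its work is done. For lazy or asynchronous
    backends, force completion inside fn (e.g. mx.eval(out) for MLX,
    torch.mps.synchronize() for PyTorch on MPS).
    """
    if repeats < 1:
        raise ValueError("repeats must be >= 1")
    # Warm-up runs are discarded: they absorb one-time costs such as
    # kernel compilation, allocation, and cache population.
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000.0)
    # Median is less sensitive than the mean to occasional OS hiccups.
    return statistics.median(samples)


if __name__ == "__main__":
    # Pure-Python stand-in workload so the sketch runs without mlx or torch.
    data = list(range(100_000))
    print(f"sum: {time_ms(lambda: sum(data)):.2f} ms")
```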

  • mlx_gpu: mlx framework with gpu backend
  • mlx_gpu_compile: mlx framework with gpu backend, with the operation compiled (mx.compile)
  • mlx_cpu: mlx framework with cpu backend
  • cpu: torch framework with cpu backend
  • mps: torch framework with mps (gpu) backend
  • mlx_gpu_compile/mlx_gpu speedup: runtime speedup of mlx_gpu_compile compared to mlx_gpu
  • mlx_gpu/mps speedup: runtime speedup of mlx_gpu compared to mps
  • mlx_gpu/mlx_cpu speedup: runtime speedup of mlx_gpu compared to mlx_cpu
  • cuda/cpu speedup: runtime speedup of cuda compared to cpu
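The formula behind the speedup columns is not stated explicitly. It is inferred here from the table values: speedup of A over B = (time_B / time_A − 1) × 100%. For example, the first M1 row has mlx_gpu = 2.72 ms and mps = 3.93 ms, giving +44%. Recomputed values can differ from the table by about a percent, because the displayed times are themselves rounded. `speedup_pct` and `format_speedup` are illustrative names.

```python
def speedup_pct(t_new_ms: float, t_ref_ms: float) -> float:
    """Percent speedup of the backend timed at t_new_ms over the one at t_ref_ms.

    Positive: t_new_ms is faster than t_ref_ms. Negative: slower.
    """
    if t_new_ms <= 0 or t_ref_ms <= 0:
        raise ValueError("runtimes must be positive")
    return (t_ref_ms / t_new_ms - 1.0) * 100.0


def format_speedup(pct: float) -> str:
    """Render as a signed whole percentage, e.g. '+44%' or '-51%'."""
    return f"{round(pct):+d}%"


if __name__ == "__main__":
    # M1, "Argmax / dim=64x1024x128 axi=0": mlx_gpu 2.72 ms vs mps 3.93 ms
    print(format_speedup(speedup_pct(2.72, 3.93)))  # +44%
    # M1, "Conv1d / dim=100x256x3 dim=8x3x3": mlx_gpu 1.13 ms vs mps 0.55 ms
    print(format_speedup(speedup_pct(1.13, 0.55)))  # -51%
```

With this convention, a ratio above 1 becomes a positive percentage and a ratio below 1 becomes a negative one. As a result, "-51%" means mlx_gpu took roughly twice as long as mps, not that it was half as fast in an additive sense.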

Apple Silicon

M1 (cores: 4E+4P+8GPU)

Operation mlx_gpu mlx_cpu mps cpu mlx_gpu/mps speedup mlx_gpu/mlx_cpu speedup
Argmax / dim=64x1024x128 axi=0 2.72 12.43 3.93 20.56 +44% +356%
Argmax / dim=64x1024x128 axi=1 1.55 11.21 3.38 6.23 +117% +621%
Argmax / dim=64x1024x128 axi=2 1.55 8.56 2.69 3.04 +73% +452%
Argmax / dim=64x128x1024 axi=2 1.42 10.32 1.49 2.35 +5% +626%
BCE / dim=1000000 dim=1000000 1.03 7.38 1.64 1.75 +58% +615%
BCE / dim=100000x32 dim=100000x32 2.53 22.57 4.90 4.59 +93% +791%
BCE / dim=100000x64x2 dim=100000x64x2 8.96 87.92 21.47 18.42 +139% +880%
BCE / dim=128x100000 dim=128x100000 9.52 89.39 20.74 18.74 +117% +839%
Concat / dim=1000000x64 dim=1000000x32 axi=1 14.11 64.02 13.84 38.98 -1% +353%
Concat / dim=1000000x64 dim=1000000x128 axi=1 27.23 148.54 27.91 77.07 +2% +445%
Concat / dim=1000000x64 dim=1000000x64 axi=0 17.61 85.42 17.58 41.08 0% +385%
Concat / dim=64x1000000 dim=64x1000000 axi=0 17.57 105.55 17.77 41.40 +1% +500%
Conv1d / dim=100x256x3 dim=8x3x3 1.13 0.33 0.55 2.39 -51% -70%
Conv1d / dim=100x256x256 dim=8x3x256 6.46 9.77 5.85 54.03 -9% +51%
Conv1d / dim=16x1000x80 dim=128x11x80 5.41 7.43 7.07 359.73 +30% +37%
Conv1d / dim=16x1000x3 dim=128x11x3 2.33 0.58 1.45 48.37 -37% -75%
Conv2d / dim=100x256x256x3 dim=8x3x3x3 80.43 935.78 11.06 127.84 -86% +1063%
Conv2d / dim=10x256x256x12 dim=8x3x3x12 29.16 425.12 8.57 32.76 -70% +1358%
Conv2d / dim=1x256x256x128 dim=8x3x3x128 29.64 703.02 9.62 56.69 -67% +2271%
Conv2d / dim=100x28x28x3 dim=8x3x3x3 2.57 10.70 1.11 1.35 -56% +316%
Conv2d / dim=1000x28x28x3 dim=8x3x3x3 8.82 108.78 4.94 9.04 -44% +1133%
LeakyReLU / dim=128x16x1024 0.98 1.44 0.67 0.70 -31% +47%
LeakyReLU / dim=64x128x1024 3.14 4.37 1.65 2.05 -47% +39%
Linear / dim=100x1024x32 dim=32x1024 dim=1024 22.13 60.20 13.41 62.32 -39% +171%
Linear / dim=100x1024x64 dim=64x1024 dim=1024 22.15 64.97 21.02 72.80 -5% +193%
Linear / dim=100x1024x256 dim=256x1024 dim=1024 40.84 96.80 79.19 170.94 +93% +137%
Linear / dim=100x1024x512 dim=512x1024 dim=1024 66.18 144.25 153.90 281.38 +132% +117%
Linear / dim=100x1x51200 dim=51200x1 dim=1 0.75 0.40 1.00 0.96 +32% -47%
MatMul / dim=32x1x1000 dim=32x1000x128 0.56 0.90 0.94 0.81 +65% +60%
MatMul / dim=1000x64x256 dim=256x32 1.70 4.17 5.03 15.57 +195% +144%
MatMul / dim=1000x64x1024 dim=1000x1024x32 8.28 17.74 21.41 253.21 +158% +114%
MatMul / dim=1000x1024x64 dim=1000x64x256 25.69 93.79 25.91 1506.15 +0% +265%
MatMul / dim=64x1000000 dim=1000000x32 41.69 18.05 33.39 56.56 -19% -56%
MatMul / dim=1000000x64 dim=64x1024 80.34 428.28 200.55 1194.50 +149% +433%
PReLU / dim=128x16x1024 dim=1 1.52 2.48 0.65 0.56 -57% +63%
PReLU / dim=64x128x1024 dim=1 5.48 6.59 1.65 2.07 -69% +20%
ReLU / dim=128x16x1024 0.54 0.53 0.66 0.68 +21% -1%
ReLU / dim=64x128x1024 1.43 1.27 1.60 2.01 +12% -11%
SeLU / dim=128x16x1024 3.34 6.65 0.65 3.26 -80% +98%
SeLU / dim=64x128x1024 12.27 22.81 1.63 12.17 -86% +85%
Sigmoid / dim=128x16x1024 0.51 13.28 0.67 3.24 +31% +2511%
Sigmoid / dim=64x128x1024 1.42 52.05 1.64 11.21 +15% +3565%
Softmax / dim=64x1000000 axi=-1 13.33 36.49 17.54 55.15 +31% +173%
Softmax / dim=1000000x64 axi=-1 8.85 34.18 21.06 59.54 +138% +286%
Softmax / dim=64x16x32x1024 axi=-1 4.81 18.03 16.84 23.87 +250% +274%
Softmax / dim=128x16x32x1024 axi=-1 9.31 36.18 31.50 54.39 +238% +288%
Softmax / dim=1024x16x32x128 axi=-1 9.43 34.99 22.24 57.47 +135% +271%
Softmax / dim=1024x64x32x8 axi=-1 15.19 86.01 6.42 29.72 -57% +466%
Softplus / dim=128x16x1024 0.69 13.48 0.94 4.55 +35% +1860%
Softplus / dim=64x128x1024 1.46 52.67 2.52 17.42 +72% +3503%
Sort / dim=64x128x1024 axi=0 34.19 900.00 95.15 72.40 +178% +2532%
Sort / dim=64x128x1024 axi=1 18.08 976.16 85.37 63.35 +372% +5298%
Sort / dim=64x128x1024 axi=2 3.20 263.54 39.22 74.59 +1123% +8124%
Sum / dim=64x128x128x128 axi=0 11.36 17.39 30.79 17.97 +170% +53%
Sum / dim=64x128x128x128 axi=1 9.74 12.66 10.87 15.02 +11% +30%
Sum / dim=64x128x128x128 axi=2 9.39 11.23 10.52 10.80 +11% +19%
Sum / dim=64x128x128x128 axi=3 14.82 9.51 13.21 9.90 -10% -35%
SumAll / dim=64x128x128x128 9.02 8.81 9.82 9.38 +8% -2%
SumAll / dim=1000000 0.52 0.08 0.43 0.07 -17% -85%
SumAll / dim=1000000x128 9.13 9.22 9.89 9.34 +8% +0%
SumAll / dim=128x1000000 8.95 9.28 9.46 9.18 +5% +3%
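The rows above are flattened. Operation labels such as `Argmax / dim=64x1024x128 axi=0` contain spaces, so a row cannot be split into columns on whitespace alone. The sketch below loads such rows, under two assumptions: the numeric columns are always the last N whitespace-separated tokens, in header order (true for the rows in this document), and `nan` marks a missing measurement. `parse_row`, `Row`, and the column lists are hypothetical helpers, not part of any benchmark tooling.

```python
from dataclasses import dataclass

# Column orders copied from the table headers in this document.
M1_COLUMNS = ["mlx_gpu", "mlx_cpu", "mps", "cpu",
              "mlx_gpu/mps speedup", "mlx_gpu/mlx_cpu speedup"]
M1_PRO_COLUMNS = ["mlx_gpu", "mlx_gpu_compile", "mlx_cpu", "mps", "cpu",
                  "mlx_gpu_compile/mlx_gpu speedup", "mlx_gpu/mps speedup",
                  "mlx_gpu/mlx_cpu speedup"]


@dataclass
class Row:
    operation: str            # e.g. "Argmax / dim=64x1024x128 axi=0"
    values: dict[str, float]  # column name -> runtime (ms) or speedup (%)


def parse_row(line: str, columns: list[str]) -> Row:
    """Split a flattened table row into its operation label and numeric columns."""
    tokens = line.split()
    n = len(columns)
    if len(tokens) <= n:
        raise ValueError(f"expected more than {n} tokens in {line!r}")
    # The numeric columns are the last n tokens; everything before is the label.
    label, numeric = tokens[:-n], tokens[-n:]
    values = {}
    for name, tok in zip(columns, numeric):
        # Speedup cells end in '%'; 'nan' and 'nan%' both become float('nan').
        values[name] = float(tok.rstrip("%"))
    return Row(operation=" ".join(label), values=values)


if __name__ == "__main__":
    row = parse_row(
        "Argmax / dim=64x1024x128 axi=0 2.72 12.43 3.93 20.56 +44% +356%",
        M1_COLUMNS,
    )
    print(row.operation, row.values["mps"], row.values["mlx_gpu/mps speedup"])
```

The column list has to be supplied by the caller, because multi-word header names such as "mlx_gpu/mps speedup" are just as ambiguous as the row labels.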

M1 Pro (cores: 2E+8P+16GPU, 16GB) mlx 0.5.0

Operation mlx_gpu mlx_gpu_compile mlx_cpu mps cpu mlx_gpu_compile/mlx_gpu speedup mlx_gpu/mps speedup mlx_gpu/mlx_cpu speedup
Argmax / dim=64x1024x128 axi=0 1.77 1.71 10.57 1.39 23.23 +3% -21% +498%
Argmax / dim=64x1024x128 axi=1 1.76 1.71 10.53 1.43 5.46 +2% -18% +498%
Argmax / dim=64x1024x128 axi=2 1.74 1.85 10.54 0.69 2.19 -6% -60% +507%
Argmax / dim=64x128x1024 axi=2 1.74 1.69 10.57 0.56 1.90 +2% -67% +508%
BCE / dim=1000000 dim=1000000 0.46 0.39 7.89 0.49 1.28 +16% +6% +1628%
BCE / dim=100000x32 dim=100000x32 1.07 0.54 25.53 0.64 3.47 +97% -39% +2296%
BCE / dim=100000x64x2 dim=100000x64x2 3.63 1.47 102.71 1.12 14.53 +146% -69% +2728%
BCE / dim=128x100000 dim=128x100000 3.57 1.46 101.88 1.11 14.64 +144% -68% +2757%
Concat / dim=1000000x64 dim=1000000x32 axi=1 4.38 4.43 65.75 4.52 25.64 -1% +3% +1401%
Concat / dim=1000000x64 dim=1000000x128 axi=1 8.48 8.52 144.59 8.62 45.77 0% +1% +1605%
Concat / dim=1000000x64 dim=1000000x64 axi=0 5.96 5.75 60.70 5.85 37.96 +3% -1% +919%
Concat / dim=64x1000000 dim=64x1000000 axi=0 5.74 5.80 80.50 5.86 37.60 -1% +2% +1303%
Conv1d / dim=100x256x3 dim=8x3x3 0.53 0.36 0.36 0.34 2.72 +48% -34% -30%
Conv1d / dim=100x256x256 dim=8x3x256 3.11 3.06 8.35 1.52 74.37 +1% -51% +168%
Conv1d / dim=16x1000x80 dim=128x11x80 2.78 2.72 4.39 1.79 485.21 +2% -35% +58%
Conv1d / dim=16x1000x3 dim=128x11x3 0.61 0.41 0.57 0.40 55.22 +48% -34% -7%
Conv2d / dim=100x256x256x3 dim=8x3x3x3 8.54 8.46 976.61 6.35 136.53 +1% -25% +11330%
Conv2d / dim=10x256x256x12 dim=8x3x3x12 17.35 17.26 420.71 2.52 26.52 +0% -85% +2324%
Conv2d / dim=1x256x256x128 dim=8x3x3x128 1.11 1.11 706.93 2.35 37.35 0% +112% +63694%
Conv2d / dim=100x28x28x3 dim=8x3x3x3 0.34 0.33 10.45 0.38 1.44 +1% +12% +3001%
Conv2d / dim=1000x28x28x3 dim=8x3x3x3 1.22 1.17 104.42 1.01 8.77 +4% -17% +8448%
Gather / dim=64x256 dim=10 0.20 0.21 0.01 0.17 0.01 -5% -14% -93%
Gather / dim=64x256 dim=1000 0.24 0.26 0.04 0.25 0.12 -7% +5% -84%
Gather / dim=64x256 dim=1000000 15.77 15.75 22.90 99.52 46.30 +0% +531% +45%
Gather / dim=1024x32 dim=10 0.20 0.20 0.01 0.22 0.00 +1% +7% -94%
Gather / dim=1024x32 dim=1000 0.20 0.22 0.02 0.23 0.11 -7% +14% -90%
Gather / dim=1024x32 dim=1000000 2.30 2.38 6.73 12.83 7.66 -3% +458% +193%
LeakyReLU / dim=128x16x1024 0.31 0.28 0.32 0.32 0.45 +13% +3% +1%
LeakyReLU / dim=64x128x1024 0.60 0.59 1.27 0.61 1.98 +0% +1% +113%
Linear / dim=100x1024x32 dim=32x1024 dim=1024 3.12 3.23 24.39 8.32 63.76 -3% +166% +682%
Linear / dim=100x1024x64 dim=64x1024 dim=1024 4.64 4.62 31.95 13.27 73.28 +0% +186% +589%
Linear / dim=100x1024x256 dim=256x1024 dim=1024 13.76 14.31 46.53 49.23 156.42 -3% +257% +238%
Linear / dim=100x1024x512 dim=512x1024 dim=1024 25.94 25.99 69.94 94.68 341.82 0% +265% +169%
Linear / dim=100x1x51200 dim=51200x1 dim=1 0.41 0.64 0.43 0.54 3.84 -36% +31% +4%
MatMul / dim=32x1x1000 dim=32x1000x128 0.27 0.27 0.89 0.30 0.82 -1% +13% +232%
MatMul / dim=1000x64x256 dim=256x32 0.99 0.92 1.65 1.28 27.50 +7% +29% +66%
MatMul / dim=1000x64x1024 dim=1000x1024x32 3.05 3.06 18.26 13.78 359.44 0% +351% +498%
MatMul / dim=1000x1024x64 dim=1000x64x256 9.30 9.28 42.73 9.05 1088.01 +0% -2% +359%
MatMul / dim=64x1000000 dim=1000000x32 14.72 15.05 10.64 19.92 211.27 -2% +35% -27%
MatMul / dim=1000000x64 dim=64x1024 34.80 35.31 155.57 92.25 1305.17 -1% +165% +346%
PReLU / dim=128x16x1024 dim=1 0.36 0.34 1.35 0.43 0.42 +6% +18% +271%
PReLU / dim=64x128x1024 dim=1 0.59 0.58 5.38 0.67 1.71 +1% +13% +815%
ReLU / dim=128x16x1024 0.34 0.26 0.21 0.46 0.45 +30% +34% -38%
ReLU / dim=64x128x1024 0.59 0.59 1.04 0.65 1.72 +0% +9% +75%
Scatter / dim=64x16 dim=10 0.20 0.19 0.01 0.21 0.00 +6% +1% -92%
Scatter / dim=64x16 dim=1000 0.22 0.19 0.07 0.19 0.06 +19% -13% -66%
Scatter / dim=64x16 dim=1000000 1.06 1.06 62.23 6.51 4.35 0% +516% +5793%
Scatter / dim=1024x32 dim=10 0.25 0.20 0.02 0.23 0.00 +23% -4% -92%
Scatter / dim=1024x32 dim=1000 0.21 0.21 0.13 0.22 0.07 +2% +2% -36%
Scatter / dim=1024x32 dim=1000000 1.59 1.59 117.67 12.90 7.14 0% +709% +7290%
ScatterSum / dim=64x16 dim=10 0.04 0.04 0.01 nan 0.00 +17% nan% -70%
ScatterSum / dim=64x16 dim=1000 0.04 0.03 0.01 nan 0.01 +11% nan% -73%
ScatterSum / dim=64x16 dim=1000000 0.03 0.03 0.01 nan 1.70 +9% nan% -71%
ScatterSum / dim=1024x32 dim=10 0.03 0.03 0.01 nan 0.01 +11% nan% -70%
ScatterSum / dim=1024x32 dim=1000 0.03 0.03 0.01 nan 0.01 +8% nan% -67%
ScatterSum / dim=1024x32 dim=1000000 0.03 0.09 0.01 nan 7.09 -63% nan% -71%
ScatterMax / dim=64x16 dim=10 0.03 0.03 0.01 nan 0.00 -7% nan% -68%
ScatterMax / dim=64x16 dim=1000 0.03 0.03 0.01 nan 0.01 +9% nan% -72%
ScatterMax / dim=64x16 dim=1000000 0.03 0.03 0.01 nan 1.66 +6% nan% -70%
ScatterMax / dim=1024x32 dim=10 0.03 0.03 0.01 nan 0.01 -9% nan% -61%
ScatterMax / dim=1024x32 dim=1000 0.04 0.03 0.01 nan 0.01 +23% nan% -71%
ScatterMax / dim=1024x32 dim=1000000 0.03 0.06 0.01 nan 6.98 -47% nan% -70%
SeLU / dim=128x16x1024 0.43 0.32 1.95 0.31 2.68 +34% -27% +358%
SeLU / dim=64x128x1024 0.60 0.60 7.77 0.62 10.76 +0% +3% +1197%
Sigmoid / dim=128x16x1024 0.29 0.29 1.87 0.36 2.46 +1% +23% +540%
Sigmoid / dim=64x128x1024 0.60 0.59 7.29 0.74 10.32 +0% +23% +1122%
Softmax / dim=64x1000000 axi=-1 11.51 8.80 50.97 5.94 33.72 +30% -48% +342%
Softmax / dim=1000000x64 axi=-1 11.52 8.74 51.12 8.64 36.07 +31% -25% +343%
Softmax / dim=64x16x32x1024 axi=-1 6.18 4.83 27.12 3.32 18.95 +28% -46% +338%
Softmax / dim=128x16x32x1024 axi=-1 12.06 9.20 53.53 6.27 33.52 +31% -48% +344%
Softmax / dim=1024x16x32x128 axi=-1 12.07 9.25 53.57 9.18 34.96 +30% -23% +343%
Softmax / dim=1024x64x32x8 axi=-1 3.27 3.12 13.67 2.44 24.18 +4% -25% +318%
Softplus / dim=128x16x1024 0.32 0.31 14.10 0.33 3.66 +2% +1% +4301%
Softplus / dim=64x128x1024 0.59 0.67 56.43 0.65 14.27 -12% +9% +9461%
Sort / dim=64x128x1024 axi=0 1.70 1.77 259.52 51.80 63.23 -3% +2939% +15131%
Sort / dim=64x128x1024 axi=1 1.69 1.69 258.09 44.60 51.60 +0% +2537% +15162%
Sort / dim=64x128x1024 axi=2 1.69 1.69 257.45 16.87 60.86 0% +901% +15174%
Sum / dim=64x128x128x128 axi=0 3.38 3.46 9.19 11.78 13.63 -2% +248% +171%
Sum / dim=64x128x128x128 axi=1 3.37 3.47 9.21 3.26 13.42 -2% -3% +173%
Sum / dim=64x128x128x128 axi=2 3.40 3.46 9.36 3.25 7.71 -1% -4% +175%
Sum / dim=64x128x128x128 axi=3 3.38 3.46 9.23 5.97 5.31 -2% +76% +173%
SumAll / dim=64x128x128x128 3.37 3.72 9.17 3.31 4.67 -9% -1% +172%
SumAll / dim=1000000 0.26 0.26 0.06 0.40 0.10 +1% +53% -76%
SumAll / dim=1000000x128 3.24 3.26 9.32 3.13 4.51 0% -3% +188%
SumAll / dim=128x1000000 3.21 3.30 8.77 3.08 4.56 -2% -3% +173%

M1 Max (64GB) mlx 0.2.0

Operation mlx_gpu mlx_gpu_compile mlx_cpu mps cpu mlx_gpu_compile/mlx_gpu speedup mlx_gpu/mps speedup mlx_gpu/mlx_cpu speedup
Argmax / dim=64x1024x128 axi=0 2.08 1.67 11.14 4.58 25.92 +25% +119% +434%
Argmax / dim=64x1024x128 axi=1 2.20 1.72 10.85 1.48 6.47 +27% -32% +393%
Argmax / dim=64x1024x128 axi=2 2.19 1.68 10.62 1.01 2.35 +30% -53% +385%
Argmax / dim=64x128x1024 axi=2 2.11 1.69 10.58 0.63 1.92 +25% -70% +401%
BCE / dim=1000000 dim=1000000 0.52 0.35 6.62 0.48 1.36 +50% -8% +1172%
BCE / dim=100000x32 dim=100000x32 0.71 0.45 21.63 0.69 3.47 +56% -3% +2953%
BCE / dim=100000x64x2 dim=100000x64x2 1.98 0.91 86.42 1.47 13.99 +118% -25% +4266%
BCE / dim=128x100000 dim=128x100000 1.98 0.91 86.40 1.39 13.53 +117% -29% +4268%
Concat / dim=1000000x64 dim=1000000x32 axi=1 2.34 2.33 78.58 2.41 18.34 +0% +2% +3260%
Concat / dim=1000000x64 dim=1000000x128 axi=1 4.43 4.41 146.43 4.52 42.51 +0% +2% +3206%
Concat / dim=1000000x64 dim=1000000x64 axi=0 3.02 3.02 62.03 3.07 19.04 +0% +1% +1952%
Concat / dim=64x1000000 dim=64x1000000 axi=0 3.03 3.02 82.37 3.08 19.29 +0% +1% +2622%
Conv1d / dim=100x256x3 dim=8x3x3 0.39 0.39 0.36 0.40 2.57 +0% +0% -7%
Conv1d / dim=100x256x256 dim=8x3x256 4.05 1.63 8.22 1.80 69.41 +149% -55% +102%
Conv1d / dim=16x1000x80 dim=128x11x80 2.10 1.38 4.20 1.87 497.10 +52% -10% +100%
Conv1d / dim=16x1000x3 dim=128x11x3 2.27 0.53 0.59 0.64 59.95 +326% -71% -74%
Conv2d / dim=100x256x256x3 dim=8x3x3x3 18.90 18.76 1034.14 3.49 131.73 +0% -81% +5370%
Conv2d / dim=10x256x256x12 dim=8x3x3x12 9.18 6.97 419.33 2.23 18.17 +31% -75% +4467%
Conv2d / dim=1x256x256x128 dim=8x3x3x128 9.23 7.42 697.01 2.21 18.67 +24% -76% +7449%
Conv2d / dim=100x28x28x3 dim=8x3x3x3 0.98 0.78 10.66 0.53 1.45 +25% -46% +988%
Conv2d / dim=1000x28x28x3 dim=8x3x3x3 2.60 2.25 116.24 1.46 7.77 +15% -43% +4363%
Gather / dim=64x256 dim=10 0.23 0.23 0.01 0.23 0.01 0% 0% -94%
Gather / dim=64x256 dim=1000 0.33 0.32 0.03 0.33 0.12 +4% +0% -89%
Gather / dim=64x256 dim=1000000 11.57 11.54 28.36 50.98 46.58 +0% +340% +145%
Gather / dim=1024x32 dim=10 0.25 0.23 0.01 0.22 0.00 +7% -8% -94%
Gather / dim=1024x32 dim=1000 0.26 0.25 0.02 0.25 0.09 +3% -4% -92%
Gather / dim=1024x32 dim=1000000 2.42 1.66 7.20 6.66 6.73 +45% +175% +197%
LeakyReLU / dim=128x16x1024 0.40 0.30 1.79 0.37 0.68 +31% -8% +347%
LeakyReLU / dim=64x128x1024 0.67 0.37 7.02 0.54 0.59 +83% -19% +941%
Linear / dim=100x1024x32 dim=32x1024 dim=1024 4.35 3.74 19.60 3.40 31.02 +16% -21% +350%
Linear / dim=100x1024x64 dim=64x1024 dim=1024 4.71 4.34 27.89 5.88 37.61 +8% +24% +492%
Linear / dim=100x1024x256 dim=256x1024 dim=1024 8.88 8.86 49.44 24.42 53.09 +0% +174% +456%
Linear / dim=100x1024x512 dim=512x1024 dim=1024 14.83 14.81 64.98 47.81 75.07 +0% +222% +338%
Linear / dim=100x1x51200 dim=51200x1 dim=1 0.85 0.71 0.37 0.67 0.44 +19% -21% -56%
MatMul / dim=32x1x1000 dim=32x1000x128 0.31 0.26 0.84 0.40 0.87 +17% +29% +173%
MatMul / dim=1000x64x256 dim=256x32 0.62 0.60 1.71 2.02 2.31 +3% +225% +176%
MatMul / dim=1000x64x1024 dim=1000x1024x32 1.71 1.70 17.97 7.38 20.63 +0% +331% +949%
MatMul / dim=1000x1024x64 dim=1000x64x256 4.84 4.83 64.09 4.96 122.87 +0% +2% +1224%
MatMul / dim=64x1000000 dim=1000000x32 2.87 2.88 10.51 11.88 10.47 0% +313% +265%
MatMul / dim=1000000x64 dim=64x1024 17.58 17.56 187.87 40.33 372.75 +0% +129% +968%
PReLU / dim=128x16x1024 dim=1 0.57 0.33 1.09 0.36 0.55 +73% -36% +93%
PReLU / dim=64x128x1024 dim=1 1.08 0.39 4.19 0.52 0.59 +172% -51% +289%
ReLU / dim=128x16x1024 0.32 0.29 0.31 0.37 0.63 +10% +15% -3%
ReLU / dim=64x128x1024 0.41 0.38 1.33 0.52 0.58 +8% +26% +224%
Scatter / dim=64x16 dim=10 0.23 0.24 0.01 0.22 0.00 -5% -5% -93%
Scatter / dim=64x16 dim=1000 0.31 0.29 0.07 0.25 0.05 +4% -18% -75%
Scatter / dim=64x16 dim=1000000 7.99 7.94 62.95 3.41 4.35 +0% -57% +687%
Scatter / dim=1024x32 dim=10 0.28 0.26 0.02 0.22 0.00 +6% -19% -93%
Scatter / dim=1024x32 dim=1000 0.31 0.29 0.13 0.26 0.08 +6% -16% -57%
Scatter / dim=1024x32 dim=1000000 15.54 15.51 118.69 6.72 6.19 +0% -56% +663%
ScatterSum / dim=64x16 dim=10 0.05 0.03 0.01 nan 0.00 +46% nan% -82%
ScatterSum / dim=64x16 dim=1000 0.05 0.03 0.01 nan 0.01 +41% nan% -80%
ScatterSum / dim=64x16 dim=1000000 0.05 0.03 0.01 nan 1.60 +54% nan% -83%
ScatterSum / dim=1024x32 dim=10 0.04 0.03 0.01 nan 0.01 +35% nan% -78%
ScatterSum / dim=1024x32 dim=1000 0.05 0.03 0.01 nan 0.01 +41% nan% -81%
ScatterSum / dim=1024x32 dim=1000000 0.05 0.03 0.01 nan 6.49 +38% nan% -81%
ScatterMax / dim=64x16 dim=10 0.05 0.03 0.01 nan 0.00 +41% nan% -82%
ScatterMax / dim=64x16 dim=1000 0.05 0.03 0.01 nan 0.01 +37% nan% -82%
ScatterMax / dim=64x16 dim=1000000 0.05 0.03 0.01 nan 1.54 +37% nan% -80%
ScatterMax / dim=1024x32 dim=10 0.04 0.03 0.01 nan 0.01 +26% nan% -79%
ScatterMax / dim=1024x32 dim=1000 0.05 0.03 0.01 nan 0.01 +36% nan% -81%
ScatterMax / dim=1024x32 dim=1000000 0.04 0.03 0.01 nan 6.55 +29% nan% -79%
SeLU / dim=128x16x1024 0.98 0.33 2.75 0.36 2.62 +202% -63% +178%
SeLU / dim=64x128x1024 2.07 0.40 11.36 0.57 9.32 +423% -72% +449%
Sigmoid / dim=128x16x1024 0.34 0.32 13.10 0.40 2.43 +4% +18% +3784%
Sigmoid / dim=64x128x1024 0.43 0.39 52.05 0.60 8.43 +9% +39% +12049%
Softmax / dim=64x1000000 axi=-1 5.86 4.49 53.32 3.30 33.16 +30% -43% +809%
Softmax / dim=1000000x64 axi=-1 5.87 4.48 53.33 4.89 34.35 +31% -16% +807%
Softmax / dim=64x16x32x1024 axi=-1 3.25 2.50 28.15 3.18 19.36 +29% -2% +766%
Softmax / dim=128x16x32x1024 axi=-1 6.13 4.68 55.99 5.01 31.86 +30% -18% +813%
Softmax / dim=1024x16x32x128 axi=-1 6.13 4.68 55.89 5.14 33.61 +31% -16% +811%
Softmax / dim=1024x64x32x8 axi=-1 1.79 1.44 14.22 1.79 21.23 +24% 0% +692%
Softplus / dim=128x16x1024 0.52 0.32 13.09 0.47 3.57 +62% -9% +2405%
Softplus / dim=64x128x1024 0.62 0.37 52.49 0.87 12.93 +68% +40% +8373%
Sort / dim=64x128x1024 axi=0 1.08 0.97 257.13 29.79 52.22 +11% +2661% +23735%
Sort / dim=64x128x1024 axi=1 1.09 0.98 257.19 21.35 43.94 +11% +1862% +23537%
Sort / dim=64x128x1024 axi=2 1.06 0.97 257.20 9.39 51.75 +9% +782% +24073%
Sum / dim=64x128x128x128 axi=0 1.75 1.74 8.96 5.57 15.39 +0% +218% +411%
Sum / dim=64x128x128x128 axi=1 1.76 1.74 8.84 1.80 13.53 +0% +2% +403%
Sum / dim=64x128x128x128 axi=2 1.74 1.74 8.84 1.79 7.31 +0% +2% +406%
Sum / dim=64x128x128x128 axi=3 1.74 1.74 8.83 3.25 5.18 +0% +86% +406%
SumAll / dim=64x128x128x128 1.74 1.73 8.82 1.84 4.50 +0% +5% +405%
SumAll / dim=1000000 0.32 0.30 0.06 0.36 0.10 +7% +15% -81%
SumAll / dim=1000000x128 1.68 1.67 8.68 1.93 4.47 +1% +14% +415%
SumAll / dim=128x1000000 1.68 1.67 8.95 1.87 4.35 +0% +11% +434%

M2 - mlx 0.2.0

Operation mlx_gpu mlx_cpu mps cpu mlx_gpu/mps speedup mlx_gpu/mlx_cpu speedup
Argmax / dim=64x1024x128 axi=0 2.57 16.16 3.93 22.12 +52% +528%
Argmax / dim=64x1024x128 axi=1 1.47 18.03 3.35 6.93 +127% +1126%
Argmax / dim=64x1024x128 axi=2 1.44 13.91 2.39 2.67 +65% +866%
Argmax / dim=64x128x1024 axi=2 1.35 16.74 1.18 2.32 -12% +1140%
BCE / dim=1000000 dim=1000000 0.83 11.36 1.93 1.96 +132% +1267%
BCE / dim=100000x32 dim=100000x32 1.84 35.85 6.22 5.58 +238% +1849%
BCE / dim=100000x64x2 dim=100000x64x2 6.08 140.95 21.14 21.82 +247% +2219%
BCE / dim=128x100000 dim=128x100000 6.10 141.21 22.96 23.99 +276% +2215%
Concat / dim=1000000x64 dim=1000000x32 axi=1 8.65 99.86 8.95 34.82 +3% +1054%
Concat / dim=1000000x64 dim=1000000x128 axi=1 17.19 221.80 17.74 62.82 +3% +1190%
Concat / dim=1000000x64 dim=1000000x64 axi=0 11.35 142.18 11.57 44.15 +1% +1152%
Concat / dim=64x1000000 dim=64x1000000 axi=0 11.37 180.44 11.78 43.52 +3% +1487%
Conv1d / dim=100x256x3 dim=8x3x3 0.56 0.56 0.59 3.62 +5% 0%
Conv1d / dim=100x256x256 dim=8x3x256 5.08 13.73 5.53 56.72 +8% +170%
Conv1d / dim=16x1000x80 dim=128x11x80 6.01 8.79 5.61 392.02 -6% +46%
Conv1d / dim=16x1000x3 dim=128x11x3 3.00 0.95 1.44 78.40 -51% -68%
Conv2d / dim=100x256x256x3 dim=8x3x3x3 77.90 1491.19 10.80 173.72 -86% +1814%
Conv2d / dim=10x256x256x12 dim=8x3x3x12 25.70 690.25 7.00 42.99 -72% +2585%
Conv2d / dim=1x256x256x128 dim=8x3x3x128 24.70 1154.58 6.52 55.78 -73% +4574%
Conv2d / dim=100x28x28x3 dim=8x3x3x3 2.00 17.35 1.17 1.87 -41% +769%
Conv2d / dim=1000x28x28x3 dim=8x3x3x3 8.41 175.54 4.20 10.03 -50% +1987%
LeakyReLU / dim=128x16x1024 0.80 1.49 0.70 0.92 -12% +84%
LeakyReLU / dim=64x128x1024 2.20 3.39 1.50 1.75 -31% +54%
Linear / dim=100x1024x32 dim=32x1024 dim=1024 16.90 74.45 14.00 83.08 -17% +340%
Linear / dim=100x1024x64 dim=64x1024 dim=1024 16.82 90.75 23.52 104.08 +39% +439%
Linear / dim=100x1024x256 dim=256x1024 dim=1024 33.88 131.98 83.12 255.48 +145% +289%
Linear / dim=100x1024x512 dim=512x1024 dim=1024 56.74 201.06 164.84 471.68 +190% +254%
Linear / dim=100x1x51200 dim=51200x1 dim=1 0.72 0.30 0.99 0.78 +37% -57%
MatMul / dim=32x1x1000 dim=32x1000x128 0.46 0.26 0.85 1.35 +85% -42%
MatMul / dim=1000x64x256 dim=256x32 1.14 4.86 5.17 13.63 +351% +324%
MatMul / dim=1000x64x1024 dim=1000x1024x32 4.84 24.73 19.62 188.50 +305% +410%
MatMul / dim=1000x1024x64 dim=1000x64x256 17.95 110.02 20.34 1452.25 +13% +512%
MatMul / dim=64x1000000 dim=1000000x32 39.95 21.26 33.21 49.77 -16% -46%
MatMul / dim=1000000x64 dim=64x1024 67.87 562.54 389.42 2072.27 +473% +728%
PReLU / dim=128x16x1024 dim=1 1.18 2.88 0.71 0.95 -39% +145%
PReLU / dim=64x128x1024 dim=1 3.69 6.27 1.37 1.74 -62% +69%
ReLU / dim=128x16x1024 0.46 0.66 0.67 0.93 +44% +42%
ReLU / dim=64x128x1024 1.08 1.33 1.34 1.76 +23% +23%
SeLU / dim=128x16x1024 2.39 7.83 0.71 3.37 -70% +227%
SeLU / dim=64x128x1024 8.22 26.67 1.51 13.12 -81% +224%
Sigmoid / dim=128x16x1024 0.46 21.21 0.76 3.12 +66% +4535%
Sigmoid / dim=64x128x1024 1.07 84.49 1.50 11.83 +39% +7760%
Softmax / dim=64x1000000 axi=-1 8.80 55.52 11.74 71.44 +33% +530%
Softmax / dim=1000000x64 axi=-1 6.00 55.61 16.18 79.75 +169% +827%
Softmax / dim=64x16x32x1024 axi=-1 3.23 28.86 12.77 28.85 +294% +792%
Softmax / dim=128x16x32x1024 axi=-1 6.16 57.68 23.52 75.04 +282% +836%
Softmax / dim=1024x16x32x128 axi=-1 6.13 55.61 17.27 75.92 +181% +807%
Softmax / dim=1024x64x32x8 axi=-1 12.10 140.43 5.74 34.53 -52% +1060%
Softplus / dim=128x16x1024 0.61 21.86 1.01 5.15 +64% +3457%
Softplus / dim=64x128x1024 1.20 86.03 2.45 19.10 +104% +7069%
Sort / dim=64x128x1024 axi=0 31.34 1619.90 56.94 78.79 +81% +5068%
Sort / dim=64x128x1024 axi=1 16.30 1686.63 52.80 73.93 +224% +10249%
Sort / dim=64x128x1024 axi=2 2.97 423.22 30.32 85.22 +920% +14145%
Sum / dim=64x128x128x128 axi=0 10.23 25.07 18.06 18.94 +76% +144%
Sum / dim=64x128x128x128 axi=1 6.19 19.27 6.41 17.45 +3% +211%
Sum / dim=64x128x128x128 axi=2 6.17 17.14 6.26 11.48 +1% +177%
Sum / dim=64x128x128x128 axi=3 14.01 12.03 11.15 8.90 -20% -14%
SumAll / dim=64x128x128x128 5.75 10.66 6.72 8.27 +17% +85%
SumAll / dim=1000000 0.38 0.10 0.48 0.10 +27% -73%
SumAll / dim=1000000x128 5.53 10.20 6.37 8.07 +15% +84%
SumAll / dim=128x1000000 5.61 10.20 6.26 7.99 +11% +81%

M2 Pro (cores: 4E+6P+16GPU) mlx 0.12.2 torch 2.1.2

Operation mlx_gpu mlx_gpu_compile mlx_cpu mps cpu mlx_gpu_compile/mlx_gpu speedup mlx_gpu/mps speedup mlx_gpu/mlx_cpu speedup
Argmax / dim=64x1024x128 axi=0 1.62 1.52 9.94 2.21 22.96 +6% +36% +514%
Argmax / dim=64x1024x128 axi=1 1.55 1.52 9.94 1.07 4.71 +2% -30% +542%
Argmax / dim=64x1024x128 axi=2 1.52 1.52 10.08 1.24 2.03 0% -17% +564%
Argmax / dim=64x128x1024 axi=2 1.53 1.52 9.96 0.60 1.83 +0% -60% +552%
BCE / dim=1000000 dim=1000000 0.34 0.28 7.84 0.48 1.20 +21% +39% +2178%
BCE / dim=100000x32 dim=100000x32 0.91 0.43 25.49 0.51 3.21 +113% -43% +2691%
BCE / dim=100000x64x2 dim=100000x64x2 3.42 1.29 102.36 1.02 13.53 +164% -70% +2890%
BCE / dim=128x100000 dim=128x100000 3.43 1.29 102.44 0.93 14.10 +166% -72% +2885%
Concat / dim=1000000x64 dim=1000000x32 axi=1 4.69 4.61 72.12 4.47 25.65 +1% -4% +1439%
Concat / dim=1000000x64 dim=1000000x128 axi=1 8.53 9.29 139.19 8.77 46.38 -8% +2% +1531%
Concat / dim=1000000x64 dim=1000000x64 axi=0 5.75 6.10 56.55 5.91 36.72 -5% +2% +883%
Concat / dim=64x1000000 dim=64x1000000 axi=0 5.71 5.69 77.44 5.87 37.16 +0% +2% +1256%
Conv1d / dim=100x256x3 dim=8x3x3 0.47 0.28 0.34 0.38 2.57 +66% -19% -28%
Conv1d / dim=100x256x256 dim=8x3x256 2.83 2.83 7.68 1.20 66.92 0% -57% +171%
Conv1d / dim=16x1000x80 dim=128x11x80 2.55 2.49 3.68 1.51 459.58 +2% -40% +44%
Conv1d / dim=16x1000x3 dim=128x11x3 0.63 0.35 0.49 0.51 60.69 +78% -19% -21%
Conv2d / dim=100x256x256x3 dim=8x3x3x3 7.85 7.83 908.32 5.14 142.65 +0% -34% +11470%
Conv2d / dim=10x256x256x12 dim=8x3x3x12 15.94 15.94 390.84 2.29 28.57 +0% -85% +2351%
Conv2d / dim=1x256x256x128 dim=8x3x3x128 0.92 0.91 647.13 1.90 37.38 +1% +106% +70295%
Conv2d / dim=100x28x28x3 dim=8x3x3x3 0.25 0.23 9.88 0.34 1.50 +8% +37% +3866%
Conv2d / dim=1000x28x28x3 dim=8x3x3x3 1.06 1.05 98.79 0.90 8.27 +0% -15% +9220%
Gather / dim=64x256 dim=10 0.16 0.15 0.02 0.37 0.00 +9% +133% -86%
Gather / dim=64x256 dim=1000 0.16 0.17 0.04 0.22 0.09 -4% +37% -73%
Gather / dim=64x256 dim=1000000 15.34 15.29 19.20 83.49 50.93 +0% +444% +25%
Gather / dim=1024x32 dim=10 0.14 0.14 0.02 0.19 0.00 -2% +37% -85%
Gather / dim=1024x32 dim=1000 0.14 0.15 0.02 0.21 0.07 -6% +47% -82%
Gather / dim=1024x32 dim=1000000 2.21 2.19 5.47 10.47 7.61 +1% +374% +147%
LeakyReLU / dim=128x16x1024 0.20 0.20 0.63 0.26 0.35 -1% +34% +221%
LeakyReLU / dim=64x128x1024 0.53 0.53 1.16 0.61 1.51 0% +14% +119%
Linear / dim=100x1024x32 dim=32x1024 dim=1024 3.46 3.15 19.01 6.10 73.56 +9% +76% +449%
Linear / dim=100x1024x64 dim=64x1024 dim=1024 4.70 4.68 22.28 11.83 82.54 +0% +151% +373%
Linear / dim=100x1024x256 dim=256x1024 dim=1024 13.16 13.17 37.06 46.17 152.52 0% +250% +181%
Linear / dim=100x1024x512 dim=512x1024 dim=1024 24.89 24.90 56.73 92.15 260.77 0% +270% +127%
Linear / dim=100x1x51200 dim=51200x1 dim=1 0.62 0.55 0.24 0.43 6.08 +11% -30% -61%
MatMul / dim=32x1x1000 dim=32x1000x128 0.17 0.20 0.14 0.25 0.72 -15% +49% -15%
MatMul / dim=1000x64x256 dim=256x32 0.89 0.81 2.03 1.09 37.03 +9% +22% +128%
MatMul / dim=1000x64x1024 dim=1000x1024x32 2.89 2.88 16.20 12.88 601.25 +0% +345% +460%
MatMul / dim=1000x1024x64 dim=1000x64x256 8.88 8.89 51.04 8.62 2242.84 0% -3% +474%
MatMul / dim=64x1000000 dim=1000000x32 3.36 3.34 9.31 22.10 149.75 +0% +556% +176%
MatMul / dim=1000000x64 dim=64x1024 49.39 43.20 135.55 84.62 1493.04 +14% +71% +174%
PReLU / dim=128x16x1024 dim=1 0.50 0.24 1.76 0.27 0.37 +105% -45% +251%
PReLU / dim=64x128x1024 dim=1 0.56 0.53 5.15 0.60 1.45 +4% +7% +821%
ReLU / dim=128x16x1024 0.30 0.21 0.48 0.28 0.37 +41% -6% +59%
ReLU / dim=64x128x1024 0.52 0.52 0.99 0.58 1.48 0% +11% +90%
Scatter / dim=64x16 dim=10 0.14 0.15 0.02 0.11 0.00 -3% -20% -86%
Scatter / dim=64x16 dim=1000 0.14 0.15 0.08 0.11 0.05 -5% -18% -41%
Scatter / dim=64x16 dim=1000000 0.48 0.49 58.67 5.43 4.97 -1% +1032% +12151%
Scatter / dim=1024x32 dim=10 0.15 0.15 0.02 0.21 0.00 -1% +40% -83%
Scatter / dim=1024x32 dim=1000 0.15 0.14 0.13 0.14 0.07 +3% -1% -9%
Scatter / dim=1024x32 dim=1000000 0.82 0.80 110.60 10.62 8.78 +1% +1199% +13425%
ScatterSum / dim=64x16 dim=10 0.04 0.04 0.02 nan 0.00 +6% nan% -44%
ScatterSum / dim=64x16 dim=1000 0.04 0.04 0.02 nan 0.01 +7% nan% -54%
ScatterSum / dim=64x16 dim=1000000 0.03 0.04 0.02 nan 1.55 -3% nan% -49%
ScatterSum / dim=1024x32 dim=10 0.03 0.03 0.02 nan 0.01 +1% nan% -49%
ScatterSum / dim=1024x32 dim=1000 0.03 0.03 0.02 nan 0.01 +5% nan% -49%
ScatterSum / dim=1024x32 dim=1000000 0.04 0.03 0.02 nan 6.70 +4% nan% -51%
ScatterMax / dim=64x16 dim=10 0.03 0.03 0.02 nan 0.00 +4% nan% -47%
ScatterMax / dim=64x16 dim=1000 0.04 0.03 0.02 nan 0.00 +15% nan% -53%
ScatterMax / dim=64x16 dim=1000000 0.03 0.03 0.02 nan 1.53 +4% nan% -48%
ScatterMax / dim=1024x32 dim=10 0.03 0.03 0.02 nan 0.01 +7% nan% -44%
ScatterMax / dim=1024x32 dim=1000 0.04 0.04 0.02 nan 0.01 +8% nan% -50%
ScatterMax / dim=1024x32 dim=1000000 0.04 0.03 0.02 nan 6.71 +2% nan% -51%
SeLU / dim=128x16x1024 0.43 0.32 2.34 0.38 1.11 +35% -11% +442%
SeLU / dim=64x128x1024 0.54 0.54 7.36 0.64 4.21 +0% +17% +1264%
Sigmoid / dim=128x16x1024 0.21 0.22 1.73 0.32 0.91 -3% +48% +705%
Sigmoid / dim=64x128x1024 0.53 0.52 6.93 0.63 3.55 +3% +18% +1196%
Softmax / dim=64x1000000 axi=-1 11.36 8.52 48.40 5.85 30.57 +33% -48% +326%
Softmax / dim=1000000x64 axi=-1 11.31 8.51 48.41 5.98 30.12 +32% -47% +328%
Softmax / dim=64x16x32x1024 axi=-1 6.01 4.54 26.38 3.21 13.34 +32% -46% +338%
Softmax / dim=128x16x32x1024 axi=-1 11.79 8.94 51.04 6.29 27.94 +31% -46% +332%
Softmax / dim=1024x16x32x128 axi=-1 11.92 8.96 51.12 6.02 29.37 +33% -49% +328%
Softmax / dim=1024x64x32x8 axi=-1 3.10 2.48 13.00 1.94 18.65 +25% -37% +318%
Softplus / dim=128x16x1024 0.28 0.20 13.77 0.39 1.89 +36% +40% +4836%
Softplus / dim=64x128x1024 0.53 0.54 53.73 0.55 7.57 -1% +3% +9993%
Sort / dim=64x128x1024 axi=0 1.48 1.49 243.85 24.49 56.65 0% +1550% +16337%
Sort / dim=64x128x1024 axi=1 1.48 1.48 242.90 26.92 45.41 0% +1722% +16343%
Sort / dim=64x128x1024 axi=2 1.48 1.48 240.89 15.80 53.13 0% +969% +16204%
Sum / dim=64x128x128x128 axi=0 3.25 3.22 9.24 3.17 14.01 +1% -2% +183%
Sum / dim=64x128x128x128 axi=1 3.23 3.22 9.16 3.15 13.28 +0% -2% +183%
Sum / dim=64x128x128x128 axi=2 3.29 3.23 9.06 3.11 8.24 +1% -5% +175%
Sum / dim=64x128x128x128 axi=3 3.21 3.21 8.99 2.94 4.89 +0% -8% +180%
SumAll / dim=64x128x128x128 3.22 3.21 9.01 3.16 4.53 +0% -2% +179%
SumAll / dim=1000000 0.18 0.18 0.06 0.26 0.08 +1% +41% -65%
SumAll / dim=1000000x128 3.05 3.05 8.64 3.03 4.33 +0% 0% +183%
SumAll / dim=128x1000000 3.04 3.04 8.59 3.01 4.31 +0% -1% +182%

M2 Max (cores: 4E+8P+38GPU) mlx 0.5.0 torch 2.2.1

Operation mlx_gpu mlx_gpu_compile mlx_cpu mps cpu mlx_gpu_compile/mlx_gpu speedup mlx_gpu/mps speedup mlx_gpu/mlx_cpu speedup
Argmax / dim=64x1024x128 axi=0 1.51 1.51 10.03 0.87 22.78 0% -42% +566%
Argmax / dim=64x1024x128 axi=1 1.50 1.50 10.12 0.70 6.33 +0% -53% +573%
Argmax / dim=64x1024x128 axi=2 1.50 1.53 10.17 0.67 2.87 -2% -55% +577%
Argmax / dim=64x128x1024 axi=2 1.51 1.51 10.07 0.50 2.54 0% -66% +568%
BCE / dim=1000000 dim=1000000 0.24 0.24 8.06 0.36 1.91 0% +49% +3260%
BCE / dim=100000x32 dim=100000x32 0.40 0.23 26.57 0.54 3.81 +75% +33% +6499%
BCE / dim=100000x64x2 dim=100000x64x2 1.68 0.65 102.48 0.75 15.20 +157% -55% +5992%
BCE / dim=128x100000 dim=128x100000 1.68 0.65 102.54 0.74 15.13 +160% -55% +5993%
Concat / dim=1000000x64 dim=1000000x32 axi=1 2.47 2.23 59.85 2.39 19.93 +10% -3% +2321%
Concat / dim=1000000x64 dim=1000000x128 axi=1 4.33 4.33 138.19 4.59 44.71 0% +6% +3094%
Concat / dim=1000000x64 dim=1000000x64 axi=0 2.99 3.07 56.44 3.27 21.91 -2% +9% +1790%
Concat / dim=64x1000000 dim=64x1000000 axi=0 2.96 3.00 77.67 3.12 22.41 -1% +5% +2528%
Conv1d / dim=100x256x3 dim=8x3x3 0.36 0.25 0.34 0.29 2.95 +42% -18% -3%
Conv1d / dim=100x256x256 dim=8x3x256 1.36 1.32 7.66 0.78 83.60 +2% -42% +463%
Conv1d / dim=16x1000x80 dim=128x11x80 1.26 1.19 3.61 0.78 488.82 +6% -38% +185%
Conv1d / dim=16x1000x3 dim=128x11x3 0.47 0.27 0.49 0.28 66.21 +74% -40% +4%
Conv2d / dim=100x256x256x3 dim=8x3x3x3 3.34 3.32 927.83 2.40 123.16 +0% -28% +27652%
Conv2d / dim=10x256x256x12 dim=8x3x3x12 6.89 7.18 406.23 1.21 19.08 -3% -82% +5792%
Conv2d / dim=1x256x256x128 dim=8x3x3x128 0.53 0.57 677.14 1.14 19.61 -8% +117% +128850%
Conv2d / dim=100x28x28x3 dim=8x3x3x3 0.21 0.24 10.24 0.28 1.27 -14% +35% +4879%
Conv2d / dim=1000x28x28x3 dim=8x3x3x3 1.27 0.90 101.18 0.70 8.39 +40% -44% +7888%
Gather / dim=64x256 dim=10 0.17 0.17 0.01 0.19 0.00 -3% +13% -91%
Gather / dim=64x256 dim=1000 0.21 0.26 0.03 0.16 0.06 -18% -23% -83%
Gather / dim=64x256 dim=1000000 6.45 7.37 18.01 42.44 46.71 -12% +558% +179%
Gather / dim=1024x32 dim=10 0.13 0.19 0.01 0.35 0.00 -28% +158% -90%
Gather / dim=1024x32 dim=1000 0.14 0.21 0.02 0.13 0.04 -35% -9% -87%
Gather / dim=1024x32 dim=1000000 0.97 1.20 5.46 5.48 7.06 -18% +462% +460%
LeakyReLU / dim=128x16x1024 0.16 0.22 0.29 0.20 0.70 -27% +28% +87%
LeakyReLU / dim=64x128x1024 0.29 0.38 1.14 0.41 1.73 -22% +38% +289%
Linear / dim=100x1024x32 dim=32x1024 dim=1024 3.80 4.12 15.30 2.60 31.46 -7% -31% +302%
Linear / dim=100x1024x64 dim=64x1024 dim=1024 3.92 4.12 16.73 4.73 34.36 -4% +20% +326%
Linear / dim=100x1024x256 dim=256x1024 dim=1024 7.32 7.37 32.59 19.37 51.09 0% +164% +345%
Linear / dim=100x1024x512 dim=512x1024 dim=1024 12.00 12.01 52.74 37.70 72.46 0% +214% +339%
Linear / dim=100x1x51200 dim=51200x1 dim=1 0.49 0.51 0.25 0.46 0.22 -2% -6% -49%
MatMul / dim=32x1x1000 dim=32x1000x128 0.15 0.16 0.14 0.41 0.22 -4% +163% -12%
MatMul / dim=1000x64x256 dim=256x32 0.48 0.43 1.60 0.79 2.04 +12% +63% +232%
MatMul / dim=1000x64x1024 dim=1000x1024x32 1.36 1.37 16.35 6.74 17.97 -1% +395% +1102%
MatMul / dim=1000x1024x64 dim=1000x64x256 4.24 4.32 33.97 4.50 126.12 -1% +6% +701%
MatMul / dim=64x1000000 dim=1000000x32 2.68 2.73 9.67 10.16 10.01 -1% +278% +260%
MatMul / dim=1000000x64 dim=64x1024 13.69 13.94 102.82 36.07 344.95 -1% +163% +651%
PReLU / dim=128x16x1024 dim=1 0.25 0.64 1.32 0.37 0.46 -61% +50% +437%
PReLU / dim=64x128x1024 dim=1 0.32 0.33 5.21 0.48 1.61 -3% +51% +1551%
ReLU / dim=128x16x1024 0.45 0.18 0.20 0.25 0.39 +152% -44% -54%
ReLU / dim=64x128x1024 0.29 0.31 1.03 0.45 1.49 -6% +53% +256%
Scatter / dim=64x16 dim=10 0.12 0.14 0.01 0.17 0.00 -12% +39% -89%
Scatter / dim=64x16 dim=1000 0.14 0.15 0.07 0.12 0.03 -5% -16% -48%
Scatter / dim=64x16 dim=1000000 0.30 0.39 60.03 2.75 4.59 -22% +807% +19730%
Scatter / dim=1024x32 dim=10 0.14 0.14 0.02 0.12 0.00 -4% -9% -88%
Scatter / dim=1024x32 dim=1000 0.14 0.13 0.13 0.13 0.05 +5% -9% -8%
Scatter / dim=1024x32 dim=1000000 0.48 0.51 113.00 5.55 6.26 -6% +1057% +23455%
ScatterSum / dim=64x16 dim=10 0.04 0.03 0.01 nan 0.00 +25% nan% -69%
ScatterSum / dim=64x16 dim=1000 0.03 0.03 0.01 nan 0.00 +1% nan% -70%
ScatterSum / dim=64x16 dim=1000000 0.03 0.03 0.01 nan 1.52 +6% nan% -68%
ScatterSum / dim=1024x32 dim=10 0.03 0.03 0.01 nan 0.01 +16% nan% -72%
ScatterSum / dim=1024x32 dim=1000 0.03 0.03 0.01 nan 0.01 +2% nan% -67%
ScatterSum / dim=1024x32 dim=1000000 0.03 0.03 0.01 nan 6.66 +12% nan% -69%
ScatterMax / dim=64x16 dim=10 0.03 0.03 0.01 nan 0.00 +7% nan% -67%
ScatterMax / dim=64x16 dim=1000 0.03 0.03 0.01 nan 0.00 +11% nan% -70%
ScatterMax / dim=64x16 dim=1000000 0.03 0.03 0.01 nan 1.53 +14% nan% -67%
ScatterMax / dim=1024x32 dim=10 0.03 0.03 0.01 nan 0.01 +10% nan% -67%
ScatterMax / dim=1024x32 dim=1000 0.03 0.03 0.01 nan 0.01 +7% nan% -70%
ScatterMax / dim=1024x32 dim=1000000 0.03 0.03 0.01 nan 6.81 +9% nan% -66%
SeLU / dim=128x16x1024 0.27 0.35 1.88 0.47 3.07 -21% +70% +590%
SeLU / dim=64x128x1024 0.30 0.37 7.36 0.51 11.06 -19% +68% +2344%
Sigmoid / dim=128x16x1024 0.18 0.20 1.73 0.27 3.35 -9% +48% +857%
Sigmoid / dim=64x128x1024 0.30 0.33 6.94 0.46 9.35 -10% +54% +2237%
Softmax / dim=64x1000000 axi=-1 5.61 4.42 48.87 3.16 35.05 +27% -43% +770%
Softmax / dim=1000000x64 axi=-1 5.64 4.39 48.89 4.11 38.25 +28% -27% +766%
Softmax / dim=64x16x32x1024 axi=-1 3.00 2.34 26.03 1.78 19.57 +28% -40% +766%
Softmax / dim=128x16x32x1024 axi=-1 5.90 4.64 52.24 3.56 39.79 +27% -39% +785%
Softmax / dim=1024x16x32x128 axi=-1 5.99 4.58 51.29 4.47 41.25 +30% -25% +755%
Softmax / dim=1024x64x32x8 axi=-1 1.57 1.24 12.89 1.33 25.57 +26% -15% +719%
Softplus / dim=128x16x1024 0.21 0.17 13.89 0.25 4.38 +22% +16% +6474%
Softplus / dim=64x128x1024 0.29 0.31 55.56 0.41 14.40 -7% +41% +18921%
Sort / dim=64x128x1024 axi=0 0.73 0.73 249.05 11.60 63.98 +0% +1494% +34121%
Sort / dim=64x128x1024 axi=1 0.73 0.79 247.89 12.59 52.87 -7% +1625% +33879%
Sort / dim=64x128x1024 axi=2 0.73 0.73 249.72 7.75 59.78 0% +961% +34078%
Sum / dim=64x128x128x128 axi=0 1.64 1.65 9.97 1.69 19.74 0% +2% +506%
Sum / dim=64x128x128x128 axi=1 1.60 1.65 9.08 1.65 15.95 -2% +3% +467%
Sum / dim=64x128x128x128 axi=2 1.60 1.63 8.89 1.67 7.12 -1% +3% +454%
Sum / dim=64x128x128x128 axi=3 1.59 1.64 8.92 2.82 5.40 -2% +76% +459%
SumAll / dim=64x128x128x128 1.61 1.65 9.36 1.70 5.19 -2% +5% +483%
SumAll / dim=1000000 0.15 0.15 0.06 0.21 0.06 -2% +41% -60%
SumAll / dim=1000000x128 1.52 1.56 9.23 1.66 4.92 -2% +9% +507%
SumAll / dim=128x1000000 1.53 1.56 8.81 1.71 5.19 -1% +12% +476%

M2 Ultra (cores: 8E+16P+76GPU) mlx 0.7.0

Operation mlx_gpu mlx_gpu_compile mlx_cpu mps cpu mlx_gpu_compile/mlx_gpu speedup mlx_gpu/mps speedup mlx_gpu/mlx_cpu speedup
Argmax / dim=64x1024x128 axi=0 1.80 1.91 9.44 0.96 33.74 -5% -46% +423%
Argmax / dim=64x1024x128 axi=1 1.52 1.53 9.49 0.57 2.44 0% -62% +522%
Argmax / dim=64x1024x128 axi=2 1.53 1.54 9.41 0.57 0.93 0% -62% +515%
Argmax / dim=64x128x1024 axi=2 1.53 1.53 9.50 0.48 0.86 +0% -68% +520%
BCE / dim=1000000 dim=1000000 0.27 0.64 7.46 0.24 1.17 -57% -12% +2659%
BCE / dim=100000x32 dim=100000x32 0.39 0.26 24.54 0.28 2.27 +52% -30% +6125%
BCE / dim=100000x64x2 dim=100000x64x2 0.96 0.46 97.27 0.67 6.88 +108% -30% +9999%
BCE / dim=128x100000 dim=128x100000 0.94 0.45 97.01 0.69 6.62 +110% -26% +10187%
Concat / dim=1000000x64 dim=1000000x32 axi=1 1.25 1.23 61.82 1.27 30.59 +1% +1% +4860%
Concat / dim=1000000x64 dim=1000000x128 axi=1 2.31 2.31 139.94 2.26 46.21 +0% -2% +5954%
Concat / dim=1000000x64 dim=1000000x64 axi=0 1.61 1.62 52.87 1.55 39.65 0% -4% +3179%
Concat / dim=64x1000000 dim=64x1000000 axi=0 1.60 1.60 73.17 1.57 39.25 0% -1% +4478%
Conv1d / dim=100x256x3 dim=8x3x3 0.24 0.25 0.39 0.22 3.75 0% -8% +58%
Conv1d / dim=100x256x256 dim=8x3x256 0.80 0.78 7.00 0.56 90.56 +1% -30% +778%
Conv1d / dim=16x1000x80 dim=128x11x80 0.75 0.74 2.67 0.69 577.37 +1% -8% +256%
Conv1d / dim=16x1000x3 dim=128x11x3 0.40 0.27 0.50 0.34 79.98 +48% -14% +26%
Conv2d / dim=100x256x256x3 dim=8x3x3x3 1.78 1.79 901.91 1.35 177.78 0% -24% +50502%
Conv2d / dim=10x256x256x12 dim=8x3x3x12 3.73 3.75 393.67 0.71 15.00 0% -80% +10449%
Conv2d / dim=1x256x256x128 dim=8x3x3x128 0.40 0.41 647.30 0.62 26.14 -1% +55% +161627%
Conv2d / dim=100x28x28x3 dim=8x3x3x3 0.21 0.22 9.55 0.27 2.02 -2% +27% +4404%
Conv2d / dim=1000x28x28x3 dim=8x3x3x3 0.62 0.71 96.47 0.38 9.29 -13% -38% +15491%
Gather / dim=64x256 dim=10 0.17 0.21 0.01 0.12 0.00 -22% -27% -94%
Gather / dim=64x256 dim=1000 0.19 0.20 0.03 0.17 0.12 -5% -7% -83%
Gather / dim=64x256 dim=1000000 3.30 3.36 17.51 20.31 60.94 -1% +515% +430%
Gather / dim=1024x32 dim=10 0.19 0.17 0.01 0.15 0.00 +10% -21% -96%
Gather / dim=1024x32 dim=1000 0.17 0.18 0.02 0.15 0.14 -1% -13% -90%
Gather / dim=1024x32 dim=1000000 0.61 0.61 5.39 2.60 9.69 0% +325% +783%
LeakyReLU / dim=128x16x1024 0.37 0.23 0.62 0.17 0.80 +62% -54% +69%
LeakyReLU / dim=64x128x1024 0.28 0.28 1.08 0.26 3.18 +1% -8% +284%
Linear / dim=100x1024x32 dim=32x1024 dim=1024 0.78 0.77 13.19 1.63 40.32 +0% +108% +1590%
Linear / dim=100x1024x64 dim=64x1024 dim=1024 1.38 1.26 14.62 2.58 41.54 +9% +86% +957%
Linear / dim=100x1024x256 dim=256x1024 dim=1024 2.98 2.99 23.04 10.22 50.67 0% +242% +672%
Linear / dim=100x1024x512 dim=512x1024 dim=1024 5.56 5.54 33.00 18.54 62.85 +0% +233% +493%
Linear / dim=100x1x51200 dim=51200x1 dim=1 0.58 0.59 0.28 0.40 0.23 -1% -31% -50%
MatMul / dim=32x1x1000 dim=32x1000x128 0.19 0.20 0.15 0.44 0.21 -3% +127% -21%
MatMul / dim=1000x64x256 dim=256x32 0.36 0.35 0.81 0.65 1.12 +2% +81% +126%
MatMul / dim=1000x64x1024 dim=1000x1024x32 0.86 0.87 17.11 4.34 19.07 -2% +407% +1900%
MatMul / dim=1000x1024x64 dim=1000x64x256 3.35 3.30 33.82 3.32 120.11 +1% 0% +909%
MatMul / dim=64x1000000 dim=1000000x32 2.90 2.90 9.65 7.94 9.58 +0% +173% +232%
MatMul / dim=1000000x64 dim=64x1024 7.51 7.55 53.73 16.86 249.21 0% +124% +615%
PReLU / dim=128x16x1024 dim=1 0.45 0.63 1.49 0.34 0.65 -28% -25% +231%
PReLU / dim=64x128x1024 dim=1 0.28 0.27 4.81 0.30 2.56 +4% +6% +1602%
ReLU / dim=128x16x1024 0.28 0.20 0.33 0.30 0.68 +35% +9% +18%
ReLU / dim=64x128x1024 0.30 0.27 1.02 0.35 2.54 +9% +17% +237%
Scatter / dim=64x16 dim=10 0.17 0.16 0.01 0.14 0.00 +5% -15% -92%
Scatter / dim=64x16 dim=1000 0.17 0.15 0.07 0.13 0.12 +13% -20% -56%
Scatter / dim=64x16 dim=1000000 0.38 0.40 55.76 1.26 4.22 -3% +228% +14467%
Scatter / dim=1024x32 dim=10 0.17 0.19 0.01 0.15 0.00 -8% -11% -92%
Scatter / dim=1024x32 dim=1000 0.18 0.17 0.12 0.15 0.08 +8% -14% -31%
Scatter / dim=1024x32 dim=1000000 0.42 0.42 106.24 2.54 4.48 0% +502% +25148%
ScatterSum / dim=64x16 dim=10 0.03 0.04 0.01 nan 0.00 -6% nan% -74%
ScatterSum / dim=64x16 dim=1000 0.03 0.03 0.01 nan 0.00 -3% nan% -75%
ScatterSum / dim=64x16 dim=1000000 0.03 0.03 0.01 nan 1.47 -6% nan% -73%
ScatterSum / dim=1024x32 dim=10 0.03 0.04 0.01 nan 0.01 -10% nan% -77%
ScatterSum / dim=1024x32 dim=1000 0.03 0.03 0.01 nan 0.01 -1% nan% -75%
ScatterSum / dim=1024x32 dim=1000000 0.04 0.03 0.01 nan 6.69 +23% nan% -79%
ScatterMax / dim=64x16 dim=10 0.03 0.03 0.01 nan 0.00 +28% nan% -78%
ScatterMax / dim=64x16 dim=1000 0.04 0.04 0.01 nan 0.00 +0% nan% -79%
ScatterMax / dim=64x16 dim=1000000 0.03 0.03 0.01 nan 1.52 +2% nan% -72%
ScatterMax / dim=1024x32 dim=10 0.03 0.02 0.01 nan 0.01 +20% nan% -72%
ScatterMax / dim=1024x32 dim=1000 0.04 0.03 0.01 nan 0.01 +17% nan% -75%
ScatterMax / dim=1024x32 dim=1000000 0.04 0.03 0.01 nan 6.66 +2% nan% -77%
SeLU / dim=128x16x1024 0.61 0.27 2.06 0.22 0.87 +126% -63% +235%
SeLU / dim=64x128x1024 0.31 0.29 6.94 0.36 2.86 +6% +16% +2164%
Sigmoid / dim=128x16x1024 0.20 0.22 1.65 0.26 0.76 -7% +31% +719%
Sigmoid / dim=64x128x1024 0.28 0.28 6.58 0.25 2.66 +0% -10% +2240%
Softmax / dim=64x1000000 axi=-1 2.98 2.27 48.06 1.61 19.73 +31% -45% +1513%
Softmax / dim=1000000x64 axi=-1 3.00 2.25 48.24 1.64 20.34 +33% -45% +1507%
Softmax / dim=64x16x32x1024 axi=-1 1.62 1.25 25.49 0.96 10.49 +29% -41% +1469%
Softmax / dim=128x16x32x1024 axi=-1 3.13 2.40 50.50 1.63 22.12 +30% -48% +1514%
Softmax / dim=1024x16x32x128 axi=-1 3.18 2.37 50.50 1.67 22.05 +34% -47% +1487%
Softmax / dim=1024x64x32x8 axi=-1 0.92 0.76 12.85 0.58 12.66 +21% -37% +1291%
Softplus / dim=128x16x1024 0.27 0.23 13.00 0.24 1.57 +20% -11% +4661%
Softplus / dim=64x128x1024 0.28 0.28 51.26 0.27 5.49 -3% 0% +18450%
Sort / dim=64x128x1024 axi=0 0.48 0.52 229.34 6.94 39.20 -6% +1333% +47258%
Sort / dim=64x128x1024 axi=1 0.48 0.47 230.20 7.74 28.25 +2% +1500% +47473%
Sort / dim=64x128x1024 axi=2 0.48 0.48 229.98 4.54 34.28 0% +853% +48194%
Sum / dim=64x128x128x128 axi=0 0.90 0.93 9.23 0.94 8.97 -2% +4% +922%
Sum / dim=64x128x128x128 axi=1 0.87 0.90 9.26 0.94 9.26 -2% +7% +960%
Sum / dim=64x128x128x128 axi=2 0.89 0.90 9.20 0.95 5.67 -1% +7% +936%
Sum / dim=64x128x128x128 axi=3 0.94 0.93 9.20 0.98 3.32 +0% +5% +883%
SumAll / dim=64x128x128x128 0.88 0.89 9.22 0.96 2.67 0% +9% +945%
SumAll / dim=1000000 0.19 0.18 0.06 0.46 0.09 +4% +144% -68%
SumAll / dim=1000000x128 0.87 0.88 8.74 1.01 2.57 -1% +15% +907%
SumAll / dim=128x1000000 0.86 0.87 8.78 0.91 2.56 -2% +6% +926%
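This section never states the speedup formula, but the numbers fit `baseline_time / candidate_time - 1`, expressed as a percentage. The baseline is the second name in each header: `mps` in `mlx_gpu/mps`, `mlx_gpu` in `mlx_gpu_compile/mlx_gpu`, and `cpu` in `cuda/cpu`. For example, in the M2 Ultra `Argmax / dim=64x1024x128 axi=0` row, 9.44 / 1.80 - 1 ≈ +424%, while the table lists +423%. That small gap is what rounding the displayed milliseconds would produce. The sketch below is inferred from the tables, not taken from the benchmark's source, and `speedup_pct` and `format_speedup` are hypothetical names.

```python
import math


def speedup_pct(baseline_ms: float, candidate_ms: float) -> float:
    """Percent speedup of `candidate` over `baseline` (formula inferred from the tables).

    Positive means the candidate ran faster, negative means it ran slower.
    A NaN timing propagates to a NaN result.
    """
    return (baseline_ms / candidate_ms - 1.0) * 100.0


def format_speedup(pct: float) -> str:
    """Render a speedup with an explicit sign and no decimals, e.g. '+424%'."""
    if math.isnan(pct):
        return "nan%"
    return f"{pct:+.0f}%"


# M2 Ultra row "Argmax / dim=64x1024x128 axi=0", timings in ms:
# mlx_gpu=1.80, mlx_gpu_compile=1.91, mlx_cpu=9.44, mps=0.96
print(format_speedup(speedup_pct(1.80, 1.91)))  # mlx_gpu_compile/mlx_gpu -> -6% (table: -5%)
print(format_speedup(speedup_pct(0.96, 1.80)))  # mlx_gpu/mps -> -47% (table: -46%)
print(format_speedup(speedup_pct(9.44, 1.80)))  # mlx_gpu/mlx_cpu -> +424% (table: +423%)
```

A negative value means the candidate was slower. A `nan` timing, such as the missing `mps` measurements in the `ScatterSum` and `ScatterMax` rows, propagates to `nan%`. When the faster timing is only a few hundredths of a millisecond, recomputing from the displayed values can drift noticeably from the listed percentage.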

M3 (RAM: 16GB) - mlx 0.2.0

Operation mlx_gpu mlx_cpu mps cpu mlx_gpu/mps speedup mlx_gpu/mlx_cpu speedup
Argmax / dim=64x1024x128 axi=0 1.64 21.87 2.08 19.89 +26% +1233%
Argmax / dim=64x1024x128 axi=1 1.03 9.30 1.62 3.32 +58% +806%
Argmax / dim=64x1024x128 axi=2 1.08 6.99 1.94 2.77 +80% +549%
Argmax / dim=64x128x1024 axi=2 1.05 8.34 1.21 2.41 +15% +696%
BCE / dim=1000000 dim=1000000 0.80 5.31 1.30 1.30 +63% +566%
BCE / dim=100000x32 dim=100000x32 1.91 18.11 3.64 3.50 +90% +846%
BCE / dim=100000x64x2 dim=100000x64x2 6.58 69.43 14.49 14.17 +120% +954%
BCE / dim=128x100000 dim=128x100000 6.92 70.35 14.92 13.61 +115% +917%
Concat / dim=1000000x64 dim=1000000x32 axi=1 8.76 54.43 8.83 28.59 +0% +521%
Concat / dim=1000000x64 dim=1000000x128 axi=1 18.46 126.17 18.08 58.87 -2% +583%
Concat / dim=1000000x64 dim=1000000x64 axi=0 11.48 65.83 11.70 26.79 +1% +473%
Concat / dim=64x1000000 dim=64x1000000 axi=0 11.39 86.71 11.78 26.93 +3% +661%
Conv1d / dim=100x256x3 dim=8x3x3 0.51 0.33 0.44 1.90 -13% -36%
Conv1d / dim=100x256x256 dim=8x3x256 4.31 8.12 2.40 41.79 -44% +88%
Conv1d / dim=16x1000x80 dim=128x11x80 3.63 5.64 3.77 205.70 +3% +55%
Conv1d / dim=16x1000x3 dim=128x11x3 0.91 0.58 1.31 35.53 +43% -36%
Conv2d / dim=100x256x256x3 dim=8x3x3x3 41.44 730.75 7.45 98.91 -82% +1663%
Conv2d / dim=10x256x256x12 dim=8x3x3x12 16.01 318.25 4.73 31.72 -70% +1888%
Conv2d / dim=1x256x256x128 dim=8x3x3x128 18.47 551.98 6.10 42.98 -66% +2888%
Conv2d / dim=100x28x28x3 dim=8x3x3x3 1.57 9.14 0.93 1.13 -40% +480%
Conv2d / dim=1000x28x28x3 dim=8x3x3x3 4.83 90.03 2.93 7.04 -39% +1763%
LeakyReLU / dim=128x16x1024 0.77 1.15 0.68 0.76 -11% +49%
LeakyReLU / dim=64x128x1024 2.09 4.96 1.34 1.38 -35% +137%
Linear / dim=100x1024x32 dim=32x1024 dim=1024 15.25 63.82 7.10 66.79 -53% +318%
Linear / dim=100x1024x64 dim=64x1024 dim=1024 15.55 68.92 7.98 78.16 -48% +343%
Linear / dim=100x1024x256 dim=256x1024 dim=1024 28.76 96.20 21.89 202.41 -23% +234%
Linear / dim=100x1024x512 dim=512x1024 dim=1024 47.50 130.18 41.22 260.27 -13% +174%
Linear / dim=100x1x51200 dim=51200x1 dim=1 0.67 0.34 1.02 4.00 +52% -48%
MatMul / dim=32x1x1000 dim=32x1000x128 0.47 0.15 0.97 0.68 +107% -68%
MatMul / dim=1000x64x256 dim=256x32 1.23 3.40 5.30 14.38 +330% +176%
MatMul / dim=1000x64x1024 dim=1000x1024x32 4.76 11.32 4.96 158.75 +4% +137%
MatMul / dim=1000x1024x64 dim=1000x64x256 17.41 89.14 19.42 1214.55 +11% +412%
MatMul / dim=64x1000000 dim=1000000x32 9.07 12.25 7.74 40.94 -14% +35%
MatMul / dim=1000000x64 dim=64x1024 60.00 343.14 161.03 1513.56 +168% +471%
PReLU / dim=128x16x1024 dim=1 1.14 1.10 0.67 0.86 -41% -3%
PReLU / dim=64x128x1024 dim=1 3.58 4.41 1.31 1.36 -63% +23%
ReLU / dim=128x16x1024 0.50 0.33 0.62 0.64 +24% -33%
ReLU / dim=64x128x1024 1.03 2.44 1.31 1.39 +27% +137%
SeLU / dim=128x16x1024 2.28 2.92 0.61 2.87 -73% +27%
SeLU / dim=64x128x1024 8.17 12.51 1.44 10.90 -82% +53%
Sigmoid / dim=128x16x1024 0.53 10.64 0.69 2.66 +30% +1901%
Sigmoid / dim=64x128x1024 1.05 43.31 1.45 8.72 +38% +4022%
Softmax / dim=64x1000000 axi=-1 8.85 42.85 11.65 37.39 +31% +384%
Softmax / dim=1000000x64 axi=-1 6.04 38.42 12.63 40.49 +109% +535%
Softmax / dim=64x16x32x1024 axi=-1 3.29 20.42 10.26 18.81 +212% +521%
Softmax / dim=128x16x32x1024 axi=-1 6.32 41.12 20.14 37.59 +218% +550%
Softmax / dim=1024x16x32x128 axi=-1 6.34 38.48 13.63 39.88 +114% +506%
Softmax / dim=1024x64x32x8 axi=-1 7.02 66.84 4.17 21.11 -40% +852%
Softplus / dim=128x16x1024 0.43 10.47 0.66 3.93 +50% +2309%
Softplus / dim=64x128x1024 1.03 43.17 1.50 14.26 +45% +4101%
Sort / dim=64x128x1024 axi=0 23.39 1015.93 35.02 59.19 +49% +4243%
Sort / dim=64x128x1024 axi=1 12.35 926.84 32.10 56.93 +159% +7405%
Sort / dim=64x128x1024 axi=2 2.27 230.00 25.06 64.75 +1003% +10025%
Sum / dim=64x128x128x128 axi=0 6.42 12.28 6.40 19.30 0% +91%
Sum / dim=64x128x128x128 axi=1 6.25 11.50 6.39 15.18 +2% +84%
Sum / dim=64x128x128x128 axi=2 6.28 11.08 6.34 8.06 +1% +76%
Sum / dim=64x128x128x128 axi=3 8.91 10.09 7.30 6.55 -18% +13%
SumAll / dim=64x128x128x128 5.76 10.70 6.38 5.92 +10% +85%
SumAll / dim=1000000 0.29 0.05 0.39 0.07 +36% -82%
SumAll / dim=1000000x128 5.52 9.80 6.18 5.81 +12% +77%
SumAll / dim=128x1000000 5.48 10.61 6.16 5.73 +12% +93%

M3 Pro (cores: 6E+5P+14GPU)

Operation mlx_gpu mlx_cpu mps cpu mlx_gpu/mps speedup mlx_gpu/mlx_cpu speedup
Argmax / dim=64x1024x128 axi=0 1.18 21.22 1.49 19.26 +25% +1695%
Argmax / dim=64x1024x128 axi=1 0.89 8.85 1.22 2.35 +36% +890%
Argmax / dim=64x1024x128 axi=2 0.97 6.65 1.16 1.58 +19% +587%
Argmax / dim=64x128x1024 axi=2 0.89 8.12 1.07 1.37 +21% +816%
BCE / dim=1000000 dim=1000000 0.81 5.96 1.09 1.04 +34% +634%
BCE / dim=100000x32 dim=100000x32 1.76 17.78 3.01 2.86 +71% +912%
BCE / dim=100000x64x2 dim=100000x64x2 4.20 67.99 11.80 11.81 +180% +1517%
BCE / dim=128x100000 dim=128x100000 4.04 67.80 11.57 11.39 +186% +1576%
Concat / dim=1000000x64 dim=1000000x32 axi=1 5.93 51.84 6.50 29.83 +9% +773%
Concat / dim=1000000x64 dim=1000000x128 axi=1 11.57 118.19 11.91 46.27 +3% +921%
Concat / dim=1000000x64 dim=1000000x64 axi=0 7.75 62.53 8.40 37.28 +8% +706%
Concat / dim=64x1000000 dim=64x1000000 axi=0 7.74 80.78 8.66 39.03 +11% +943%
Conv1d / dim=100x256x3 dim=8x3x3 0.80 0.28 0.80 1.97 0% -64%
Conv1d / dim=100x256x256 dim=8x3x256 2.95 7.32 1.99 50.02 -32% +147%
Conv1d / dim=16x1000x80 dim=128x11x80 3.70 5.40 4.23 233.14 +14% +45%
Conv1d / dim=16x1000x3 dim=128x11x3 1.15 0.44 1.26 47.61 +9% -61%
Conv2d / dim=100x256x256x3 dim=8x3x3x3 28.54 738.77 5.85 107.76 -79% +2488%
Conv2d / dim=10x256x256x12 dim=8x3x3x12 15.23 316.15 3.41 22.42 -77% +1975%
Conv2d / dim=1x256x256x128 dim=8x3x3x128 11.07 512.34 3.73 30.77 -66% +4528%
Conv2d / dim=100x28x28x3 dim=8x3x3x3 1.40 8.92 0.93 1.04 -33% +536%
Conv2d / dim=1000x28x28x3 dim=8x3x3x3 4.07 88.95 1.54 7.37 -62% +2086%
LeakyReLU / dim=128x16x1024 1.31 0.96 0.99 0.90 -24% -26%
LeakyReLU / dim=64x128x1024 1.78 2.10 1.54 1.02 -13% +18%
Linear / dim=100x1024x32 dim=32x1024 dim=1024 11.22 42.51 5.76 38.86 -48% +278%
Linear / dim=100x1024x64 dim=64x1024 dim=1024 10.79 46.87 5.71 51.97 -47% +334%
Linear / dim=100x1024x256 dim=256x1024 dim=1024 20.39 70.91 15.84 125.66 -22% +247%
Linear / dim=100x1024x512 dim=512x1024 dim=1024 33.41 103.37 28.62 240.45 -14% +209%
Linear / dim=100x1x51200 dim=51200x1 dim=1 0.69 0.23 1.28 0.50 +86% -67%
MatMul / dim=32x1x1000 dim=32x1000x128 0.87 0.10 1.18 0.66 +36% -88%
MatMul / dim=1000x64x256 dim=256x32 0.83 2.81 3.93 16.03 +373% +238%
MatMul / dim=1000x64x1024 dim=1000x1024x32 3.25 10.29 3.55 272.27 +9% +216%
MatMul / dim=1000x1024x64 dim=1000x64x256 11.40 82.00 13.39 680.84 +17% +619%
MatMul / dim=64x1000000 dim=1000000x32 39.89 10.89 6.81 155.98 -82% -72%
MatMul / dim=1000000x64 dim=64x1024 40.00 309.51 106.34 1214.46 +165% +673%
PReLU / dim=128x16x1024 dim=1 1.47 1.98 1.09 0.79 -25% +34%
PReLU / dim=64x128x1024 dim=1 2.62 3.62 1.60 1.04 -38% +38%
ReLU / dim=128x16x1024 0.82 0.45 1.19 0.81 +43% -45%
ReLU / dim=64x128x1024 1.06 0.78 1.55 1.03 +46% -26%
SeLU / dim=128x16x1024 2.42 4.90 1.09 2.06 -55% +101%
SeLU / dim=64x128x1024 5.53 15.30 1.45 7.33 -73% +176%
Sigmoid / dim=128x16x1024 0.99 10.68 0.97 1.83 -2% +979%
Sigmoid / dim=64x128x1024 1.07 41.87 1.62 6.72 +51% +3800%
Softmax / dim=64x1000000 axi=-1 6.00 29.23 8.74 34.32 +45% +387%
Softmax / dim=1000000x64 axi=-1 4.05 28.19 9.60 37.89 +136% +595%
Softmax / dim=64x16x32x1024 axi=-1 2.27 14.97 8.09 14.76 +256% +560%
Softmax / dim=128x16x32x1024 axi=-1 4.25 29.98 16.10 35.51 +278% +605%
Softmax / dim=1024x16x32x128 axi=-1 4.23 28.42 10.15 36.84 +139% +571%
Softmax / dim=1024x64x32x8 axi=-1 6.93 64.48 3.24 19.37 -53% +830%
Softplus / dim=128x16x1024 0.71 10.54 1.05 2.72 +47% +1383%
Softplus / dim=64x128x1024 1.33 41.36 1.47 10.32 +10% +3012%
Sort / dim=64x128x1024 axi=0 15.52 1012.11 24.23 46.26 +56% +6422%
Sort / dim=64x128x1024 axi=1 8.50 898.97 22.34 46.41 +162% +10479%
Sort / dim=64x128x1024 axi=2 1.99 224.86 17.53 47.45 +782% +11220%
Sum / dim=64x128x128x128 axi=0 4.29 11.99 5.00 12.36 +16% +179%
Sum / dim=64x128x128x128 axi=1 4.17 10.66 4.98 11.53 +19% +155%
Sum / dim=64x128x128x128 axi=2 4.14 9.48 4.87 6.34 +17% +129%
Sum / dim=64x128x128x128 axi=3 6.34 7.11 5.61 5.09 -11% +12%
SumAll / dim=64x128x128x128 4.12 6.45 4.92 4.72 +19% +56%
SumAll / dim=1000000 0.67 0.06 0.75 0.08 +11% -90%
SumAll / dim=1000000x128 3.93 6.18 4.70 4.46 +19% +57%
SumAll / dim=128x1000000 3.95 6.16 4.37 4.48 +10% +56%

M3 Max (cores: 4E+12P+40GPU) mlx 0.2.0

Operation mlx_gpu mlx_gpu_compile mlx_cpu mps cpu mlx_gpu_compile/mlx_gpu speedup mlx_gpu/mps speedup mlx_gpu/mlx_cpu speedup
Argmax / dim=64x1024x128 axi=0 1.56 1.56 8.35 1.47 20.73 +0% -5% +435%
Argmax / dim=64x1024x128 axi=1 1.57 1.55 8.33 0.98 1.67 +1% -37% +430%
Argmax / dim=64x1024x128 axi=2 1.59 1.56 8.33 0.89 1.16 +1% -43% +424%
Argmax / dim=64x128x1024 axi=2 1.57 1.56 8.34 0.73 1.01 +0% -53% +432%
BCE / dim=1000000 dim=1000000 0.37 0.24 4.97 0.33 0.72 +49% -11% +1258%
BCE / dim=100000x32 dim=100000x32 0.51 0.27 16.26 0.44 1.64 +84% -13% +3110%
BCE / dim=100000x64x2 dim=100000x64x2 1.80 0.79 66.58 0.91 6.25 +128% -49% +3597%
BCE / dim=128x100000 dim=128x100000 1.80 0.78 67.05 0.68 6.29 +130% -61% +3624%
Concat / dim=1000000x64 dim=1000000x32 axi=1 2.43 2.41 66.73 2.50 16.76 +0% +2% +2645%
Concat / dim=1000000x64 dim=1000000x128 axi=1 4.58 4.57 146.32 4.68 36.45 +0% +2% +3094%
Concat / dim=1000000x64 dim=1000000x64 axi=0 3.12 3.11 47.47 3.20 19.43 +0% +2% +1419%
Concat / dim=64x1000000 dim=64x1000000 axi=0 3.13 3.12 68.53 3.20 18.91 +0% +2% +2090%
Conv1d / dim=100x256x3 dim=8x3x3 0.33 0.33 0.29 0.43 2.31 -1% +32% -11%
Conv1d / dim=100x256x256 dim=8x3x256 1.21 1.19 5.99 1.27 68.57 +2% +4% +394%
Conv1d / dim=16x1000x80 dim=128x11x80 1.41 1.03 2.88 1.49 502.60 +37% +5% +104%
Conv1d / dim=16x1000x3 dim=128x11x3 0.44 0.45 0.42 0.51 50.53 0% +13% -4%
Conv2d / dim=100x256x256x3 dim=8x3x3x3 10.22 10.13 722.62 2.18 109.15 +0% -78% +6971%
Conv2d / dim=10x256x256x12 dim=8x3x3x12 3.99 3.96 313.85 1.52 11.02 +0% -61% +7770%
Conv2d / dim=1x256x256x128 dim=8x3x3x128 4.57 4.63 512.73 1.87 27.17 -1% -59% +11117%
Conv2d / dim=100x28x28x3 dim=8x3x3x3 0.62 0.60 9.17 0.50 1.35 +3% -19% +1373%
Conv2d / dim=1000x28x28x3 dim=8x3x3x3 1.66 1.37 89.00 1.05 7.55 +20% -36% +5269%
Gather / dim=64x256 dim=10 0.33 0.21 0.01 0.58 0.00 +59% +74% -96%
Gather / dim=64x256 dim=1000 0.27 0.27 0.03 0.59 0.12 0% +115% -90%
Gather / dim=64x256 dim=1000000 6.92 6.79 20.55 41.59 38.15 +1% +501% +197%
Gather / dim=1024x32 dim=10 0.24 0.22 0.01 0.56 0.00 +4% +139% -95%
Gather / dim=1024x32 dim=1000 0.23 0.22 0.01 0.58 0.08 +2% +152% -93%
Gather / dim=1024x32 dim=1000000 1.36 1.12 5.58 5.47 1.70 +22% +301% +309%
LeakyReLU / dim=128x16x1024 0.29 0.25 0.96 0.35 0.89 +17% +19% +230%
LeakyReLU / dim=64x128x1024 0.58 0.34 4.18 0.74 0.44 +70% +27% +623%
Linear / dim=100x1024x32 dim=32x1024 dim=1024 4.06 4.03 19.47 1.91 26.53 +0% -52% +380%
Linear / dim=100x1024x64 dim=64x1024 dim=1024 4.24 4.19 23.09 2.37 55.05 +0% -44% +445%
Linear / dim=100x1024x256 dim=256x1024 dim=1024 7.47 7.45 32.52 5.70 87.50 +0% -23% +335%
Linear / dim=100x1024x512 dim=512x1024 dim=1024 12.00 11.97 48.07 10.49 125.36 +0% -12% +300%
Linear / dim=100x1x51200 dim=51200x1 dim=1 0.52 0.63 0.21 0.75 0.77 -16% +42% -60%
MatMul / dim=32x1x1000 dim=32x1000x128 0.24 0.20 0.08 0.58 0.65 +15% +142% -64%
MatMul / dim=1000x64x256 dim=256x32 0.44 0.42 1.35 1.59 15.13 +3% +264% +209%
MatMul / dim=1000x64x1024 dim=1000x1024x32 1.36 1.35 9.85 1.50 428.05 +0% +10% +624%
MatMul / dim=1000x1024x64 dim=1000x64x256 4.65 4.63 43.90 5.21 1111.34 +0% +12% +844%
MatMul / dim=64x1000000 dim=1000000x32 2.75 2.76 7.46 3.89 67.98 0% +41% +171%
MatMul / dim=1000000x64 dim=64x1024 15.77 15.79 90.78 32.97 1891.27 0% +109% +475%
PReLU / dim=128x16x1024 dim=1 0.36 0.22 0.69 0.37 0.77 +59% +2% +91%
PReLU / dim=64x128x1024 dim=1 1.04 0.35 3.43 0.62 0.46 +199% -40% +229%
ReLU / dim=128x16x1024 0.32 0.32 0.25 0.50 0.81 0% +58% -21%
ReLU / dim=64x128x1024 0.70 0.38 1.05 1.06 0.43 +84% +51% +50%
Scatter / dim=64x16 dim=10 0.66 0.43 0.01 0.44 0.00 +52% -33% -98%
Scatter / dim=64x16 dim=1000 0.42 0.38 0.07 0.47 0.07 +9% +13% -83%
Scatter / dim=64x16 dim=1000000 4.17 4.10 52.82 2.79 2.28 +1% -32% +1166%
Scatter / dim=1024x32 dim=10 0.25 0.23 0.01 0.44 0.00 +9% +76% -94%
Scatter / dim=1024x32 dim=1000 0.26 0.24 0.12 0.62 0.07 +8% +134% -55%
Scatter / dim=1024x32 dim=1000000 7.98 7.95 99.38 5.20 3.15 +0% -34% +1145%
ScatterSum / dim=64x16 dim=10 0.05 0.03 0.01 nan 0.00 +60% nan% -82%
ScatterSum / dim=64x16 dim=1000 0.04 0.03 0.01 nan 0.01 +48% nan% -80%
ScatterSum / dim=64x16 dim=1000000 0.04 0.03 0.01 nan 1.18 +54% nan% -81%
ScatterSum / dim=1024x32 dim=10 0.04 0.03 0.01 nan 0.01 +50% nan% -80%
ScatterSum / dim=1024x32 dim=1000 0.04 0.03 0.01 nan 0.01 +49% nan% -81%
ScatterSum / dim=1024x32 dim=1000000 0.04 0.03 0.01 nan 6.13 +51% nan% -80%
ScatterMax / dim=64x16 dim=10 0.04 0.03 0.01 nan 0.00 +55% nan% -80%
ScatterMax / dim=64x16 dim=1000 0.04 0.03 0.01 nan 0.00 +55% nan% -82%
ScatterMax / dim=64x16 dim=1000000 0.04 0.03 0.01 nan 1.21 +60% nan% -81%
ScatterMax / dim=1024x32 dim=10 0.04 0.03 0.01 nan 0.01 +56% nan% -82%
ScatterMax / dim=1024x32 dim=1000 0.05 0.03 0.01 nan 0.01 +46% nan% -80%
ScatterMax / dim=1024x32 dim=1000000 0.04 0.03 0.01 nan 6.16 +43% nan% -81%
SeLU / dim=128x16x1024 0.64 0.22 1.86 0.33 1.44 +187% -47% +191%
SeLU / dim=64x128x1024 2.06 0.36 8.42 0.63 4.41 +468% -69% +309%
Sigmoid / dim=128x16x1024 0.25 0.24 10.47 0.35 1.46 +6% +39% +4066%
Sigmoid / dim=64x128x1024 0.35 0.34 42.10 0.63 4.23 +3% +78% +11895%
Softmax / dim=64x1000000 axi=-1 5.78 4.35 43.94 3.26 21.04 +32% -43% +660%
Softmax / dim=1000000x64 axi=-1 5.78 4.36 43.81 4.02 20.18 +32% -30% +657%
Softmax / dim=64x16x32x1024 axi=-1 3.13 2.38 23.21 2.73 7.76 +31% -12% +641%
Softmax / dim=128x16x32x1024 axi=-1 6.05 4.56 46.01 4.37 19.09 +32% -27% +660%
Softmax / dim=1024x16x32x128 axi=-1 6.06 4.56 46.22 4.28 19.98 +33% -29% +662%
Softmax / dim=1024x64x32x8 axi=-1 1.68 1.32 11.56 1.74 10.91 +26% +3% +589%
Softplus / dim=128x16x1024 0.32 0.24 10.35 0.39 1.84 +33% +22% +3156%
Softplus / dim=64x128x1024 0.39 0.34 41.69 0.62 6.16 +13% +60% +10602%
Sort / dim=64x128x1024 axi=0 0.77 0.75 229.46 9.03 35.77 +2% +1074% +29757%
Sort / dim=64x128x1024 axi=1 0.77 0.76 229.35 8.63 33.35 +1% +1015% +29535%
Sort / dim=64x128x1024 axi=2 0.77 0.76 229.35 6.45 28.18 +1% +737% +29646%
Sum / dim=64x128x128x128 axi=0 1.55 1.55 6.54 1.69 9.59 0% +8% +321%
Sum / dim=64x128x128x128 axi=1 1.54 1.52 6.52 1.66 7.84 +1% +8% +324%
Sum / dim=64x128x128x128 axi=2 1.54 1.54 6.53 1.62 5.63 +0% +5% +323%
Sum / dim=64x128x128x128 axi=3 1.55 1.53 6.53 2.62 4.91 +1% +68% +320%
SumAll / dim=64x128x128x128 1.54 1.54 6.52 1.65 4.38 +0% +7% +323%
SumAll / dim=1000000 0.23 0.21 0.05 0.30 0.08 +8% +29% -77%
SumAll / dim=1000000x128 1.50 1.50 6.30 1.67 4.19 +0% +11% +318%
SumAll / dim=128x1000000 1.49 1.49 6.25 1.66 4.22 +0% +10% +318%
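The tables report runtimes in milliseconds but do not say how each number was taken. A common way to time asynchronous GPU work fairly has three parts: run a few warm-up iterations, force the queued work to finish inside the timed region, and report a robust statistic. The sketch below assumes exactly that approach. `time_op_ms`, its `sync` hook, and the use of the median are all illustrative choices, not the benchmark's documented method. For MLX, whose arrays are evaluated lazily, `sync` could be `mx.eval`. For PyTorch on `mps` or `cuda`, it could be a wrapper that calls `torch.mps.synchronize()` or `torch.cuda.synchronize()`.

```python
import statistics
import time
from typing import Callable, Optional


def time_op_ms(op: Callable[[], object],
               sync: Optional[Callable[[object], None]] = None,
               warmup: int = 3,
               iters: int = 10) -> float:
    """Median wall-clock time of `op()` in milliseconds (hypothetical harness).

    `sync` receives op's result and must block until that work is complete,
    so that asynchronous or lazy backends are measured end to end.
    """
    def run_once() -> None:
        out = op()
        if sync is not None:
            sync(out)

    # Warm-up runs absorb one-time costs such as kernel compilation and allocation.
    for _ in range(warmup):
        run_once()

    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        run_once()
        samples.append((time.perf_counter() - start) * 1000.0)
    # The median is less sensitive than the mean to occasional scheduler hiccups.
    return statistics.median(samples)


print(f"{time_op_ms(lambda: sorted(range(100_000), reverse=True)):.2f} ms")
```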

CUDA GPUs

Tesla V100 PCIe (32GB / Intel Xeon Gold 5120, 14 cores, 28 threads @ 2.2GHz (Skylake), 60GB)

Operation cpu cuda cuda/cpu speedup
Argmax / dim=64x1024x128 axi=0 72.96 0.09 +80554%
Argmax / dim=64x1024x128 axi=1 25.43 0.11 +22457%
Argmax / dim=64x1024x128 axi=2 20.35 0.12 +16705%
Argmax / dim=64x128x1024 axi=2 18.61 0.09 +21652%
BCE / dim=1000000 dim=1000000 26.32 0.07 +38400%
BCE / dim=100000x32 dim=100000x32 83.80 0.11 +74241%
BCE / dim=100000x64x2 dim=100000x64x2 341.08 0.30 +115358%
BCE / dim=128x100000 dim=128x100000 341.55 0.29 +116168%
Concat / dim=1000000x64 dim=1000000x32 axi=1 277.05 1.19 +23209%
Concat / dim=1000000x64 dim=1000000x128 axi=1 571.38 2.44 +23338%
Concat / dim=1000000x64 dim=1000000x64 axi=0 336.66 1.53 +21834%
Concat / dim=64x1000000 dim=64x1000000 axi=0 338.85 1.53 +22065%
Conv1d / dim=100x256x3 dim=8x3x3 0.71 0.08 +744%
Conv1d / dim=100x256x256 dim=8x3x256 37.65 0.65 +5736%
Conv1d / dim=16x1000x80 dim=128x11x80 79.15 0.47 +16703%
Conv1d / dim=16x1000x3 dim=128x11x3 3.32 0.12 +2596%
Conv2d / dim=100x256x256x3 dim=8x3x3x3 150.83 1.99 +7480%
Conv2d / dim=10x256x256x12 dim=8x3x3x12 46.20 0.65 +7000%
Conv2d / dim=1x256x256x128 dim=8x3x3x128 49.60 1.41 +3409%
Conv2d / dim=100x28x28x3 dim=8x3x3x3 1.64 0.06 +2490%
Conv2d / dim=1000x28x28x3 dim=8x3x3x3 15.40 0.25 +6068%
Gather / dim=64x256 dim=10 0.02 0.04 -49%
Gather / dim=64x256 dim=1000 0.15 0.04 +293%
Gather / dim=64x256 dim=1000000 538.57 2.08 +25740%
Gather / dim=1024x32 dim=10 0.02 0.04 -43%
Gather / dim=1024x32 dim=1000 0.05 0.04 +37%
Gather / dim=1024x32 dim=1000000 40.86 0.31 +13082%
LeakyReLU / dim=128x16x1024 1.07 0.05 +2185%
LeakyReLU / dim=64x128x1024 9.94 0.11 +9104%
Linear / dim=100x1024x32 dim=32x1024 dim=1024 311.56 1.87 +16520%
Linear / dim=100x1024x64 dim=64x1024 dim=1024 454.32 2.04 +22143%
Linear / dim=100x1024x256 dim=256x1024 dim=1024 1287.08 5.26 +24360%
Linear / dim=100x1024x512 dim=512x1024 dim=1024 2455.45 9.67 +25289%
Linear / dim=100x1x51200 dim=51200x1 dim=1 1.49 0.08 +1728%
MatMul / dim=32x1x1000 dim=32x1000x128 1.18 0.06 +1948%
MatMul / dim=1000x64x256 dim=256x32 28.43 0.39 +7143%
MatMul / dim=1000x64x1024 dim=1000x1024x32 101.83 1.59 +6318%
MatMul / dim=1000x1024x64 dim=1000x64x256 1440.54 2.82 +50943%
MatMul / dim=64x1000000 dim=1000000x32 125.20 0.65 +19267%
MatMul / dim=1000000x64 dim=64x1024 5749.54 11.27 +50919%
PReLU / dim=128x16x1024 dim=1 1.07 0.05 +2223%
PReLU / dim=64x128x1024 dim=1 10.03 0.11 +9285%
ReLU / dim=128x16x1024 1.08 0.05 +2232%
ReLU / dim=64x128x1024 9.91 0.11 +9104%
Scatter / dim=64x16 dim=10 0.02 0.03 -40%
Scatter / dim=64x16 dim=1000 0.05 0.03 +44%
Scatter / dim=64x16 dim=1000000 20.02 0.25 +7856%
Scatter / dim=1024x32 dim=10 0.02 0.03 -40%
Scatter / dim=1024x32 dim=1000 0.05 0.03 +65%
Scatter / dim=1024x32 dim=1000000 21.35 0.32 +6560%
ScatterSum / dim=64x16 dim=10 0.02 0.05 -47%
ScatterSum / dim=64x16 dim=1000 0.03 0.05 -25%
ScatterSum / dim=64x16 dim=1000000 7.61 0.19 +3937%
ScatterSum / dim=1024x32 dim=10 0.03 0.05 -39%
ScatterSum / dim=1024x32 dim=1000 0.04 0.05 -9%
ScatterSum / dim=1024x32 dim=1000000 17.76 0.12 +14453%
ScatterMax / dim=64x16 dim=10 nan nan nan%
ScatterMax / dim=64x16 dim=1000 nan nan nan%
ScatterMax / dim=64x16 dim=1000000 nan nan nan%
ScatterMax / dim=1024x32 dim=10 nan nan nan%
ScatterMax / dim=1024x32 dim=1000 nan nan nan%
ScatterMax / dim=1024x32 dim=1000000 nan nan nan%
SeLU / dim=128x16x1024 3.64 0.05 +7573%
SeLU / dim=64x128x1024 19.47 0.11 +17827%
Sigmoid / dim=128x16x1024 2.75 0.05 +5743%
Sigmoid / dim=64x128x1024 16.17 0.11 +14728%
Softmax / dim=64x1000000 axi=-1 287.67 1.63 +17518%
Softmax / dim=1000000x64 axi=-1 274.14 0.66 +41722%
Softmax / dim=64x16x32x1024 axi=-1 140.28 0.36 +38520%
Softmax / dim=128x16x32x1024 axi=-1 280.85 0.70 +40182%
Softmax / dim=1024x16x32x128 axi=-1 279.57 0.68 +40838%
Softmax / dim=1024x64x32x8 axi=-1 66.09 0.20 +33030%
Softplus / dim=128x16x1024 7.93 0.05 +16119%
Softplus / dim=64x128x1024 36.33 0.11 +32760%
Sort / dim=64x128x1024 axi=0 567.66 3.87 +14573%
Sort / dim=64x128x1024 axi=1 409.30 1.89 +21560%
Sort / dim=64x128x1024 axi=2 602.02 2.02 +29717%
Sum / dim=64x128x128x128 axi=0 88.65 0.71 +12444%
Sum / dim=64x128x128x128 axi=1 85.62 0.68 +12433%
Sum / dim=64x128x128x128 axi=2 50.47 0.70 +7131%
Sum / dim=64x128x128x128 axi=3 44.98 0.73 +6063%
SumAll / dim=64x128x128x128 40.95 0.67 +6043%
SumAll / dim=1000000 0.24 0.04 +516%
SumAll / dim=1000000x128 39.07 0.65 +5878%
SumAll / dim=128x1000000 39.02 0.65 +5859%
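Each row in these tables is a flattened table line. It starts with a free-text operation label, followed by one timing column per backend and then one speedup column per comparison. The row can therefore be split back into fields by reading numeric tokens from the right. The sketch below assumes that speedups always end in `%`, that missing values are written as `nan` (as in the `ScatterMax` rows above), and that no label token parses as a number. `parse_row` and `times_faster` are hypothetical helpers, not part of the benchmark code.

```python
def parse_row(line: str):
    """Split a flattened row into (label, timings_ms, speedups_pct).

    Assumes label tokens never parse as numbers, every speedup token
    ends in '%', and missing values are written as 'nan'.
    """
    tokens = line.split()
    speedups = []
    # Trailing speedup columns, read right to left.
    while tokens and tokens[-1].endswith("%"):
        speedups.append(float(tokens.pop()[:-1]))  # "+80554%" -> 80554.0, "nan%" -> nan
    timings = []
    # Timing columns sit just before the speedups; stop at the first label token.
    while tokens:
        try:
            value = float(tokens[-1])  # float() also accepts "nan"
        except ValueError:
            break
        timings.append(value)
        tokens.pop()
    return " ".join(tokens), timings[::-1], speedups[::-1]


def times_faster(pct: float) -> float:
    """Convert a percent speedup into a multiplicative factor (+100% -> 2.0x)."""
    return 1.0 + pct / 100.0


label, timings, speedups = parse_row("Argmax / dim=64x1024x128 axi=0 72.96 0.09 +80554%")
print(label, timings, speedups, round(times_faster(speedups[0]), 1))
```

For the first V100 PCIe row, this yields `[72.96, 0.09]` and `[80554.0]`. `times_faster` turns +80554% into a factor of about 806.5, so the `cuda` run took roughly 1/806 of the `cpu` time.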

Tesla V100 NVLink (32GB / Intel Xeon Gold 6148, 20 cores, 40 threads @ 2.4GHz (Skylake), 60GB)

Operation cpu cuda cuda/cpu speedup
Argmax / dim=64x1024x128 axi=0 57.05 0.09 +64142%
Argmax / dim=64x1024x128 axi=1 23.09 0.11 +21144%
Argmax / dim=64x1024x128 axi=2 17.13 0.11 +14831%
Argmax / dim=64x128x1024 axi=2 15.64 0.08 +18718%
BCE / dim=1000000 dim=1000000 22.45 0.06 +35452%
BCE / dim=100000x32 dim=100000x32 72.05 0.11 +65232%
BCE / dim=100000x64x2 dim=100000x64x2 330.74 0.29 +112871%
BCE / dim=128x100000 dim=128x100000 318.96 0.29 +108848%
Concat / dim=1000000x64 dim=1000000x32 axi=1 364.51 1.18 +30747%
Concat / dim=1000000x64 dim=1000000x128 axi=1 841.81 2.42 +34620%
Concat / dim=1000000x64 dim=1000000x64 axi=0 452.68 1.53 +29421%
Concat / dim=64x1000000 dim=64x1000000 axi=0 466.35 1.53 +30291%
Conv1d / dim=100x256x3 dim=8x3x3 0.52 0.07 +600%
Conv1d / dim=100x256x256 dim=8x3x256 47.17 0.62 +7567%
Conv1d / dim=16x1000x80 dim=128x11x80 39.96 0.45 +8734%
Conv1d / dim=16x1000x3 dim=128x11x3 1.83 0.11 +1520%
Conv2d / dim=100x256x256x3 dim=8x3x3x3 207.97 1.90 +10863%
Conv2d / dim=10x256x256x12 dim=8x3x3x12 21.61 0.62 +3404%
Conv2d / dim=1x256x256x128 dim=8x3x3x128 23.09 1.35 +1608%
Conv2d / dim=100x28x28x3 dim=8x3x3x3 1.15 0.06 +1844%
Conv2d / dim=1000x28x28x3 dim=8x3x3x3 10.62 0.24 +4407%
Gather / dim=64x256 dim=10 0.02 0.04 -56%
Gather / dim=64x256 dim=1000 0.14 0.04 +266%
Gather / dim=64x256 dim=1000000 846.18 2.03 +41625%
Gather / dim=1024x32 dim=10 0.02 0.04 -54%
Gather / dim=1024x32 dim=1000 0.05 0.04 +26%
Gather / dim=1024x32 dim=1000000 122.97 0.30 +41252%
LeakyReLU / dim=128x16x1024 1.45 0.05 +2933%
LeakyReLU / dim=64x128x1024 32.44 0.11 +29854%
Linear / dim=100x1024x32 dim=32x1024 dim=1024 546.20 1.83 +29689%
Linear / dim=100x1024x64 dim=64x1024 dim=1024 521.31 1.98 +26244%
Linear / dim=100x1024x256 dim=256x1024 dim=1024 827.56 5.03 +16363%
Linear / dim=100x1024x512 dim=512x1024 dim=1024 1437.48 8.98 +15914%
Linear / dim=100x1x51200 dim=51200x1 dim=1 1.39 0.08 +1690%
MatMul / dim=32x1x1000 dim=32x1000x128 1.07 0.05 +1843%
MatMul / dim=1000x64x256 dim=256x32 19.51 0.38 +5093%
MatMul / dim=1000x64x1024 dim=1000x1024x32 69.19 1.52 +4448%
MatMul / dim=1000x1024x64 dim=1000x64x256 1241.71 2.69 +46020%
MatMul / dim=64x1000000 dim=1000000x32 84.51 0.65 +12958%
MatMul / dim=1000000x64 dim=64x1024 4573.74 10.77 +42376%
PReLU / dim=128x16x1024 dim=1 1.42 0.04 +3123%
PReLU / dim=64x128x1024 dim=1 29.68 0.11 +27829%
ReLU / dim=128x16x1024 1.45 0.04 +3135%
ReLU / dim=64x128x1024 26.68 0.11 +24987%
Scatter / dim=64x16 dim=10 0.01 0.03 -49%
Scatter / dim=64x16 dim=1000 0.03 0.03 +28%
Scatter / dim=64x16 dim=1000000 16.90 0.24 +6954%
Scatter / dim=1024x32 dim=10 0.01 0.03 -49%
Scatter / dim=1024x32 dim=1000 0.04 0.03 +37%
Scatter / dim=1024x32 dim=1000000 20.12 0.31 +6358%
ScatterSum / dim=64x16 dim=10 0.02 0.04 -62%
ScatterSum / dim=64x16 dim=1000 0.02 0.04 -38%
ScatterSum / dim=64x16 dim=1000000 12.03 0.18 +6606%
ScatterSum / dim=1024x32 dim=10 0.02 0.04 -50%
ScatterSum / dim=1024x32 dim=1000 0.03 0.04 -21%
ScatterSum / dim=1024x32 dim=1000000 28.86 0.12 +23973%
ScatterMax / dim=64x16 dim=10 nan nan nan%
ScatterMax / dim=64x16 dim=1000 nan nan nan%
ScatterMax / dim=64x16 dim=1000000 nan nan nan%
ScatterMax / dim=1024x32 dim=10 nan nan nan%
ScatterMax / dim=1024x32 dim=1000 nan nan nan%
ScatterMax / dim=1024x32 dim=1000000 nan nan nan%
SeLU / dim=128x16x1024 3.54 0.05 +7533%
SeLU / dim=64x128x1024 38.35 0.11 +35654%
Sigmoid / dim=128x16x1024 2.63 0.05 +5376%
Sigmoid / dim=64x128x1024 37.01 0.11 +34194%
Softmax / dim=64x1000000 axi=-1 333.85 1.62 +20570%
Softmax / dim=1000000x64 axi=-1 302.71 0.65 +46345%
Softmax / dim=64x16x32x1024 axi=-1 157.12 0.36 +43248%
Softmax / dim=128x16x32x1024 axi=-1 318.22 0.69 +45815%
Softmax / dim=1024x16x32x128 axi=-1 304.80 0.68 +44679%
Softmax / dim=1024x64x32x8 axi=-1 105.89 0.20 +53796%
Softplus / dim=128x16x1024 7.13 0.05 +14967%
Softplus / dim=64x128x1024 51.29 0.11 +46704%
Sort / dim=64x128x1024 axi=0 417.40 3.70 +11178%
Sort / dim=64x128x1024 axi=1 360.67 1.81 +19796%
Sort / dim=64x128x1024 axi=2 490.86 1.94 +25253%
Sum / dim=64x128x128x128 axi=0 92.62 0.70 +13038%
Sum / dim=64x128x128x128 axi=1 85.36 0.68 +12454%
Sum / dim=64x128x128x128 axi=2 49.97 0.69 +7165%
Sum / dim=64x128x128x128 axi=3 49.57 0.71 +6847%
SumAll / dim=64x128x128x128 42.74 0.66 +6368%
SumAll / dim=1000000 0.21 0.03 +528%
SumAll / dim=1000000x128 40.79 0.65 +6190%
SumAll / dim=128x1000000 40.77 0.65 +6189%

RTX4090 (Desktop / 10th Gen Intel Core i9-10940X @ 3.30GHz / 128GB)

Operation cpu cuda cuda/cpu speedup
Argmax / dim=64x1024x128 axi=0 15.92 0.04 +39326%
Argmax / dim=64x1024x128 axi=1 4.11 0.05 +7998%
Argmax / dim=64x1024x128 axi=2 3.46 0.05 +6615%
Argmax / dim=64x128x1024 axi=2 3.20 0.04 +8608%
BCE / dim=1000000 dim=1000000 3.84 0.05 +7086%
BCE / dim=100000x32 dim=100000x32 10.57 0.05 +19148%
BCE / dim=100000x64x2 dim=100000x64x2 40.02 0.22 +17863%
BCE / dim=128x100000 dim=128x100000 40.52 0.22 +17958%
Concat / dim=1000000x64 dim=1000000x32 axi=1 38.01 0.97 +3814%
Concat / dim=1000000x64 dim=1000000x128 axi=1 86.99 1.79 +4759%
Concat / dim=1000000x64 dim=1000000x64 axi=0 41.79 1.20 +3376%
Concat / dim=64x1000000 dim=64x1000000 axi=0 41.53 1.22 +3317%
Conv1d / dim=100x256x3 dim=8x3x3 0.33 0.07 +359%
Conv1d / dim=100x256x256 dim=8x3x256 5.11 0.30 +1596%
Conv1d / dim=16x1000x80 dim=128x11x80 5.24 0.13 +4017%
Conv1d / dim=16x1000x3 dim=128x11x3 0.69 0.11 +507%
Conv2d / dim=100x256x256x3 dim=8x3x3x3 22.06 0.74 +2862%
Conv2d / dim=10x256x256x12 dim=8x3x3x12 4.15 0.18 +2247%
Conv2d / dim=1x256x256x128 dim=8x3x3x128 3.45 0.15 +2202%
Conv2d / dim=100x28x28x3 dim=8x3x3x3 0.56 0.06 +832%
Conv2d / dim=1000x28x28x3 dim=8x3x3x3 2.79 0.11 +2449%
Gather / dim=64x256 dim=10 0.02 0.03 -33%
Gather / dim=64x256 dim=1000 0.11 0.04 +156%
Gather / dim=64x256 dim=1000000 103.61 1.23 +8337%
Gather / dim=1024x32 dim=10 0.03 0.05 -45%
Gather / dim=1024x32 dim=1000 0.06 0.04 +23%
Gather / dim=1024x32 dim=1000000 14.67 0.19 +7595%
LeakyReLU / dim=128x16x1024 0.43 0.03 +1519%
LeakyReLU / dim=64x128x1024 4.45 0.04 +11604%
Linear / dim=100x1024x32 dim=32x1024 dim=1024 53.35 0.59 +8943%
Linear / dim=100x1024x64 dim=64x1024 dim=1024 56.93 0.70 +8089%
Linear / dim=100x1024x256 dim=256x1024 dim=1024 79.14 1.26 +6166%
Linear / dim=100x1024x512 dim=512x1024 dim=1024 121.64 2.46 +4854%
Linear / dim=100x1x51200 dim=51200x1 dim=1 0.27 0.05 +401%
MatMul / dim=32x1x1000 dim=32x1000x128 0.16 0.05 +251%
MatMul / dim=1000x64x256 dim=256x32 2.53 0.07 +3323%
MatMul / dim=1000x64x1024 dim=1000x1024x32 8.35 0.73 +1051%
MatMul / dim=1000x1024x64 dim=1000x64x256 108.79 1.60 +6689%
MatMul / dim=64x1000000 dim=1000000x32 9.42 0.50 +1791%
MatMul / dim=1000000x64 dim=64x1024 395.58 5.23 +7468%
PReLU / dim=128x16x1024 dim=1 0.39 0.03 +1103%
PReLU / dim=64x128x1024 dim=1 4.18 0.05 +7858%
ReLU / dim=128x16x1024 0.63 0.03 +2102%
ReLU / dim=64x128x1024 4.32 0.04 +10970%
Scatter / dim=64x16 dim=10 0.01 0.02 -54%
Scatter / dim=64x16 dim=1000 0.03 0.02 +30%
Scatter / dim=64x16 dim=1000000 5.03 0.14 +3531%
Scatter / dim=1024x32 dim=10 0.02 0.02 -36%
Scatter / dim=1024x32 dim=1000 0.04 0.02 +47%
Scatter / dim=1024x32 dim=1000000 5.92 0.17 +3423%
ScatterSum / dim=64x16 dim=10 0.02 0.05 -62%
ScatterSum / dim=64x16 dim=1000 0.03 0.04 -37%
ScatterSum / dim=64x16 dim=1000000 6.75 0.11 +5789%
ScatterSum / dim=1024x32 dim=10 0.04 0.04 0%
ScatterSum / dim=1024x32 dim=1000 0.06 0.05 +16%
ScatterSum / dim=1024x32 dim=1000000 16.28 0.09 +17776%
ScatterMax / dim=64x16 dim=10 0.02 0.04 -53%
ScatterMax / dim=64x16 dim=1000 0.02 0.04 -48%
ScatterMax / dim=64x16 dim=1000000 6.60 0.19 +3439%
ScatterMax / dim=1024x32 dim=10 0.04 0.04 +1%
ScatterMax / dim=1024x32 dim=1000 0.04 0.04 -7%
ScatterMax / dim=1024x32 dim=1000000 16.46 0.12 +13167%
SeLU / dim=128x16x1024 0.77 0.04 +1917%
SeLU / dim=64x128x1024 4.65 0.04 +11906%
Sigmoid / dim=128x16x1024 0.67 0.05 +1367%
Sigmoid / dim=64x128x1024 4.60 0.05 +9782%
Softmax / dim=64x1000000 axi=-1 37.95 1.16 +3157%
Softmax / dim=1000000x64 axi=-1 27.42 0.59 +4530%
Softmax / dim=64x16x32x1024 axi=-1 14.56 0.32 +4490%
Softmax / dim=128x16x32x1024 axi=-1 28.07 0.64 +4269%
Softmax / dim=1024x16x32x128 axi=-1 27.67 0.62 +4343%
Softmax / dim=1024x64x32x8 axi=-1 30.83 0.18 +17281%
Softplus / dim=128x16x1024 1.28 0.05 +2674%
Softplus / dim=64x128x1024 5.73 0.04 +14673%
Sort / dim=64x128x1024 axi=0 42.75 1.34 +3095%
Sort / dim=64x128x1024 axi=1 49.05 0.89 +5388%
Sort / dim=64x128x1024 axi=2 48.21 0.47 +10100%
Sum / dim=64x128x128x128 axi=0 15.71 0.62 +2435%
Sum / dim=64x128x128x128 axi=1 13.05 0.62 +1997%
Sum / dim=64x128x128x128 axi=2 10.13 0.63 +1515%
Sum / dim=64x128x128x128 axi=3 9.89 0.61 +1515%
SumAll / dim=64x128x128x128 9.43 0.61 +1453%
SumAll / dim=1000000 0.04 0.03 +30%
SumAll / dim=1000000x128 9.09 0.58 +1460%
SumAll / dim=128x1000000 9.22 0.59 +1450%
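The tables never give a formula for the speedup columns. Their values are consistent with speedup = (baseline ms / accelerated ms - 1) × 100, where a negative value means the accelerated backend was slower. For example, Gather on 64x256 with 10 indices gives 0.02 / 0.03 - 1 ≈ -33%. The sketch below applies that inferred formula to one flattened row. The helper names `speedup_percent`, `format_speedup` and `parse_row` are hypothetical and not part of the benchmark code.

Because the millisecond values are printed rounded to two decimals, recomputing from them can differ noticeably from the reported percentage when timings are tiny. The reported speedups were presumably computed from unrounded timings.

```python
def speedup_percent(baseline_ms: float, accelerated_ms: float) -> float:
    """Relative speedup of the accelerated backend over the baseline, in percent.

    Positive: the accelerated backend is faster; negative: it is slower.
    (Formula inferred from the table values, not taken from the benchmark code.)
    """
    if accelerated_ms <= 0:
        raise ValueError("accelerated_ms must be positive")
    return (baseline_ms / accelerated_ms - 1.0) * 100.0


def format_speedup(pct: float) -> str:
    # Mimic the table style: integer percentage with an explicit '+' for gains.
    rounded = round(pct)
    return f"+{rounded}%" if rounded > 0 else f"{rounded}%"


def parse_row(line: str) -> tuple[str, float, float, str]:
    # In the flattened rows, the last three tokens are cpu ms, cuda ms and the
    # reported speedup; everything before them is the operation name.
    *name_tokens, cpu_ms, cuda_ms, reported = line.split()
    return " ".join(name_tokens), float(cpu_ms), float(cuda_ms), reported


if __name__ == "__main__":
    name, cpu_ms, cuda_ms, reported = parse_row(
        "Gather / dim=64x256 dim=10 0.02 0.03 -33%"
    )
    print(name, format_speedup(speedup_percent(cpu_ms, cuda_ms)), reported)
```

Recomputing a row such as BCE on 1000000 elements (3.84 ms vs 0.05 ms) gives roughly +7580% rather than the reported +7086%. This is the rounding effect described above, not a different formula.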