# 一、Numba 简介
计算机只能执行二进制的机器码, C
、 C++
等编译型语言依靠编译器将源代码转化为可执行文件后才能运行, Python
、 Java
等解释型语言使用解释器将源代码翻译后在虚拟机上执行。对于 Python
,由于解释器的存在,其执行效率比 C
语言慢几倍甚至几十倍。通过下图对各种编程语言比较情况可见 python
在消耗电量、速度以及内存使用情况上与 C
、 C++
的差距。
Numba
是一个针对 Python
的开源 JIT
编译器,由 Anaconda
公司主导开发,可以对原生代码进行 CPU
和 GPU
加速。在使用 Python
进行高性能计算时, Numba
提供的加速效果可以比肩原生的 C/C++
程序。
# 二、numba 的使用方式
通常,在对 Python
函数应用函数装饰器后,便可启用 Numba
编译器。装饰器即为函数修改器,使用十分简单的语法来转换所装饰的 Python
函数。
# 2.1 @jit
我们只需要在原来的代码上添加一行 @jit
,即可将一个函数编译成机器码,其他地方都不需要更改。 @
符号装饰了原来的代码,所以称类似写法为装饰器。
from numba import jit | |
import numpy as np | |
a = np.arange(10) | |
b = np.arange(1,11) | |
@jit | |
def test(): | |
return a+b | |
test() |
# 2.2 @njit
@jit
装饰器有一个参数 nopython
, 用于区分 numba
的运行模式, numba
有两种运行模式:一个是 nopython
模式,另一个就是 object
模式。只有在 nopython
模式下,才会获得最好的加速效果,如果 numba
发现代码有不能理解的东西,就会自动进入 object
模式,确保程序能够运行的。将装饰器改为 @jit(nopython=True)
或者 @njit
, numba
会强制使用加速的方式,不进入 object
模式,如编译不成功,则直接抛出异常。
from numba import jit | |
import numpy as np | |
a = np.arange(10) | |
b = np.arange(1,11) | |
@njit # 相当于 @jit (nonpyhon=True) | |
def test(): | |
return a+b | |
test() |
# 2.3 @vectorize
通过使用 @vectorize
装饰器,你可以对仅能对标量操作的函数进行转换,例如,如果你使用的是仅适用于标量的 python
的 math
库,则转换后就可以用于数组。这提供了类似于 numpy
数组运算。
from numba import jit, int32 | |
@vectorize | |
def func(a, b): | |
# Some operation on scalars | |
return result |
# 三、使用注意事项
(1) Numba
简单到只需要在函数上加一个装饰就能加速程序,但也有缺点。目前 Numba
只支持了 Python
原生函数和部分 NumPy
函数,其他一些场景可能不适用。
比如类似 pandas
这样的库是更高层次的封装, Numba
其实不能理解它里面做了什么,所以无法对其加速。一些大家经常用的机器学习框架,如 scikit-learn
, tensorflow
, pytorch
等,已经做了大量的优化,不适合再使用 Numba
做加速。
(2)原生 Python
速度慢的另一个重要原因是变量类型不确定。声明一个变量的语法很简单,如 a = 1
,但没有指定 a
到底是一个整数和一个浮点小数。 Python
解释器要进行大量的类型推断,会非常耗时。 引入 Numba
后, Numba
也要推断输入输出的类型,才能转化为机器码。针对这个问题, Numba
给出了名为 Eager Compilation
的优化方式。
from numba import jit, int32 | |
@jit("int32(int32, int32)", nopython=True) | |
def f2(x, y): | |
return x + y |
@jit(int32(int32, int32))
告知 Numba
你的函数在使用什么样的输入和输出,括号内是输入,括号左侧是输出。这样不会加快执行速度,但是会加快编译速度,可以更快将函数编译到机器码上。
# 四、成功使用案例
# 功能:根据输入参数获得 z(货架 i 为订单 o 提供的商品 c 的数量) | |
# 输入参数:各参数同 parameter.py | |
# 输出参数:z | |
@nb.njit( | |
nb.types.Array(nb.int32, 3, "C")(nb.types.Array(nb.int32, 2, "C"), nb.types.Array(nb.int32, 1, "C"), nb.types.Array(nb.int32, 2, "C"), | |
nb.types.Array(nb.int32, 2, "C"))) | |
def get_z(y, selected_pods, A, B): | |
z = np.zeros((y.shape[0], B.shape[0], B.shape[1]), dtype=nb.int32) | |
transform_y = [] | |
for i in range(selected_pods.shape[0]): | |
for j in range(y.shape[0]): | |
if y[j, i] == 1: | |
transform_y.append(j) | |
break | |
remaining_orders = np.copy(B) # 所有订单剩余量(n_order*n_commodity_type) | |
for i in range(selected_pods.shape[0]): # 第 i 个货架 | |
remaining_A = np.copy(A[transform_y[i]]) # 当前货架的剩余量(1*n_commodity_type) | |
for j in range(B.shape[0]): # 第 j 个订单 | |
for k in range(B.shape[1]): # 第 k 种商品 | |
if remaining_A[k] > 0 and remaining_orders[j, k] > 0: | |
if remaining_A[k] >= remaining_orders[j, k]: # 货架 i 足够为订单 k 提供 z 种商品 | |
z[transform_y[i], j, k] = remaining_orders[j, k] | |
remaining_A[k] = remaining_A[k] - remaining_orders[j, k] | |
remaining_orders[j, k] = 0 | |
else: # 货架 i 不足够为订单 k 提供 z 种商品 | |
# z[transform_y[i]][j][k] = remaining_A[k] | |
z[transform_y[i], j, k] = remaining_A[k] | |
remaining_orders[j, k] = remaining_orders[j, k] - remaining_A[k] | |
remaining_A[k] = 0 | |
return z |