🥑原理:数字水印 | 奇异值分解 SVD 的定义、原理及性质
🥑参考:Python 机器学习笔记:奇异值分解(SVD)算法
正文
对于一个图像矩阵,我们总可以将其分解为以下形式:
通过选取不同个数 Σ \Sigma Σ 矩阵中的奇异值,就可以实现图像的压缩。
如果你没有了解过原理,那么你当然看不懂这是什么意思😇
如果想要实现图像的压缩,那么可以先使用 n u m p y \mathsf{numpy} numpy 库中的 linalg.svd
函数对图像矩阵进行分解,然后提取前 k k k 个奇异值以实现 SVD 图像压缩效果。下面让我们看一下代码。
1 核心代码
定义 s v d _ c o m p r e s s i o n \mathsf{svd\_compression} svd_compression 函数:
def svd_compression(img, k):res_image = np.zeros_like(img)for i in range(img.shape[2]):U, Sigma, VT = np.linalg.svd(img[:, :, i])res_image[:, :, i] = U[:, :k].dot(np.diag(Sigma[:k])).dot(VT[:k, :])return res_image
参数说明:
- i m g \mathsf{img} img 是待处理的图像
- k \mathsf{k} k 用于设置选定前 k k k 个奇异值
代码说明:
初始化 r e s _ i m a g e \mathsf{res\_image} res_image 变量,用于存放处理结果:
res_image = np.zeros_like(img)
循环压缩每一个通道:
for i in range(img.shape[2]):U, Sigma, VT = np.linalg.svd(img[:, :, i])res_image[:, :, i] = U[:, :k].dot(np.diag(Sigma[:k])).dot(VT[:k, :])
- 参数: i m g . s h a p e [ 2 ] \mathsf{img.shape[2]} img.shape[2] 是图像的通道个数
- 第一行:对第 i i i 个通道进行 SVD 分解
- 第二行:取前 k k k 个奇异值重新构造图像
说明:由于 S i g m a \mathsf{Sigma} Sigma 矩阵除对角元素外,其余元素都为 0 \mathsf{0} 0,因此
linalg.svd
函数将其处理为一维矩阵返回。在重新构造图像时,我们需要使用np.diag
函数将其还原为对角矩阵。
2 完整代码
import numpy as np
import cv2
from matplotlib import pyplot as pltimg = cv2.imread('white_bear.jpg')
img = img[:, :, [2, 1, 0]]
print('image shape is ', img.shape)def svd_compression(img, k):res_image = np.zeros_like(img)for i in range(img.shape[2]):U, Sigma, VT = np.linalg.svd(img[:, :, i])res_image[:, :, i] = U[:, :k].dot(np.diag(Sigma[:k])).dot(VT[:k, :])return res_image# 保留前 k 个奇异值
res1 = svd_compression(img, k=300)
res2 = svd_compression(img, k=200)
res3 = svd_compression(img, k=100)
res4 = svd_compression(img, k=50)plt.subplot(1, 5, 1)
plt.title("image", fontsize=12, loc="center")
plt.axis('off')
plt.imshow(img, cmap='gray')plt.subplot(1, 5, 2)
plt.title("image", fontsize=12, loc="center")
plt.axis('off')
plt.imshow(res1, cmap='gray')plt.subplot(1, 5, 3)
plt.title("u", fontsize=12, loc="center")
plt.axis('off')
plt.imshow(res2, cmap='gray')plt.subplot(1, 5, 4)
plt.title("s", fontsize=12, loc="center")
plt.axis('off')
plt.imshow(res3, cmap='gray')plt.subplot(1, 5, 5)
plt.title("v", fontsize=12, loc="center")
plt.axis('off')
plt.imshow(res4, cmap='gray')plt.show()