python用直方图规定化实现图像风格转换
以下内容需要直方图均衡化、规定化知识
直方图均衡化应用:
图像直方图均衡化能拉伸灰度图,让像素值均匀分布在0,255之间,使图像看起来不会太亮或太暗,常用于图像增强;
直方图规定化应用:
举个例子,当我们需要对多张图像进行拼接时,我们希望这些图片的亮度、饱和度保持一致,事实上就是让它们的直方图分布一致,这时就需要直方图规定化。
直方图规定化与均衡化的思想一致,事实上就是找到各个灰度级别的映射关系。具体实现的过程中一般会选一个参考图像记为A,找到A的直方图与目标图像的直方图的映射关系,从而找到目标图像的像素以A为“参考”时的映射关系。
具体实现可参考文中链接(看完茅塞顿开)
基于python利用直方图规定化统一图像风格
参考图像
原始图像(第一行)/处理后的图像(第二行)
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
源码:
import os
import cv2
import numpy as np
from concurrent.futures import ThreadPoolExecutor
import time
def get_map(Hist):
"""
计算图像的映射关系。
:param Hist: 图像的直方图
:return: 映射关系数组
"""
# 计算概率分布 Pr,使用向量化操作
Pr = Hist / np.sum(Hist)
# 计算累计概率 Sk,使用 np.cumsum 函数
Sk = np.cumsum(Pr)
# 计算映射关系 img_map,使用向量化操作
img_map = (255 * Sk + 0.5).astype(np.uint8)
return img_map
def get_off_map(map_):
"""
计算反向映射,寻找最小期望。
:param map_: 输入的映射关系数组
:return: 反向映射数组
"""
off_map = np.zeros(256, dtype=np.uint8)
last_valid_index = 0
for i in range(256):
indices = np.where(map_ == i)[0]
if indices.size > 0:
last_valid_index = indices[0]
off_map[i] = last_valid_index
return off_map
def process_channel(channel):
"""
处理单个颜色通道,计算映射和反向映射。
:param channel: 单个颜色通道的图像数据
:return: 反向映射数组
"""
Hist = cv2.calcHist([channel], [0], None, [256], [0, 255])
map_ = get_map(Hist)
off_map = get_off_map(map_)
return off_map
def get_infer_map(infer_img):
"""
计算参考图像的映射关系。
:param infer_img: 参考图像
:return: 包含三个通道反向映射的列表
"""
with ThreadPoolExecutor(max_workers=3) as executor:
results = list(executor.map(process_channel, cv2.split(infer_img)))
return results
def get_finalmap(org_map, infer_off_map):
"""
计算原始图像到最终输出图像的映射关系。
:param org_map: 原始图像的映射关系数组
:param infer_off_map: 参考图像的反向映射数组
:return: 最终映射关系数组
"""
return infer_off_map[org_map.astype(np.uint8)]
def get_newimg(img_org, org2infer_maps):
"""
根据映射关系生成新图像。
:param img_org: 原始图像
:param org2infer_maps: 包含三个通道最终映射关系的列表
:return: 新图像
"""
b, g, r = cv2.split(img_org)
b = org2infer_maps[0][b]
g = org2infer_maps[1][g]
r = org2infer_maps[2][r]
newimg = cv2.merge([b, g, r])
return newimg
def get_new_img(img_org, infer_map):
"""
根据参考映射关系生成新图像。
:param img_org: 原始图像
:param infer_map: 参考图像的映射关系
:return: 新图像
"""
with ThreadPoolExecutor(max_workers=3) as executor:
org_Hists = list(executor.map(
lambda ch: cv2.calcHist([ch], [0], None, [256], [0, 255]),
cv2.split(img_org)
))
org_maps = [get_map(Hist) for Hist in org_Hists]
org2infer_maps = [get_finalmap(org_map, infer_map[i]) for i, org_map in enumerate(org_maps)]
return get_newimg(img_org, org2infer_maps)
def img2better(infer_map, img_path):
img_org = cv2.imread(img_path)
new_img = get_new_img(img_org, infer_map)
return new_img
if __name__ == "__main__":
t0 = time.time()
infer_img_path = 'imgss/5334010103100023.jpg'
infer_img = cv2.imread(infer_img_path)
infer_map = get_infer_map(infer_img) # 计算参考映射关系
img_path = 'imgss_533422/5334220101100008.jpg'
img_org = cv2.imread(img_path)
new_img = get_new_img(img_org, infer_map) # 根据映射关系获得新的图像
cv2.imwrite('33.jpg', new_img)
t1 = time.time()
print(t1 - t0)