热力图
import matplotlib. pyplot as plt
import seaborn as sns
import numpy as npdef create_heatmap ( people, categories, data= None , title= '热力图' , xlabel= '类别' , ylabel= '人员' , value_range= ( 0.6 , 0.95 ) , figsize= ( 10 , 6 ) , cmap= 'YlOrRd' , decimal_places= 3 ) : """创建热力图参数:people: list, 人员名称列表categories: list, 类别名称列表data: numpy.ndarray, 可选,数据矩阵。如果为None,将自动生成随机数据title: str, 图表标题xlabel: str, x轴标签ylabel: str, y轴标签value_range: tuple, 数值范围,用于生成随机数据和设置颜色范围figsize: tuple, 图形大小cmap: str, 颜色方案decimal_places: int, 显示数值的小数位数返回:matplotlib.figure.Figure: 生成的图形对象""" plt. rcParams[ 'font.sans-serif' ] = [ 'SimHei' ] plt. rcParams[ 'axes.unicode_minus' ] = False if data is None : data = np. random. uniform( value_range[ 0 ] , value_range[ 1 ] , size= ( len ( people) , len ( categories) ) ) plt. figure( figsize= figsize) sns. heatmap( data, annot= True , fmt= f'. { decimal_places} f' , cmap= cmap, xticklabels= categories, yticklabels= people, vmin= value_range[ 0 ] , vmax= value_range[ 1 ] ) plt. title( title) plt. xlabel( xlabel) plt. ylabel( ylabel) plt. tight_layout( ) return plt. gcf( )
if __name__ == "__main__" : people = [ '甲' , '乙' , '丙' ] categories = [ 'A' , 'B' , 'C' , 'D' , 'E' ] fig = create_heatmap( people, categories) plt. show( ) custom_data = np. array( [ [ 0.85 , 0.72 , 0.93 , 0.88 , 0.76 ] , [ 0.92 , 0.83 , 0.75 , 0.81 , 0.89 ] , [ 0.78 , 0.91 , 0.86 , 0.77 , 0.82 ] ] ) fig = create_heatmap( people= people, categories= categories, data= custom_data, title= '自定义数据热力图' , xlabel= '指标类别' , ylabel= '人员' , value_range= ( 0.7 , 0.95 ) , figsize= ( 12 , 8 ) , cmap= 'YlOrBl' , decimal_places= 2 ) plt. show( )
"""
# 使用不同的颜色方案
create_heatmap(people, categories, cmap='viridis')# 更改数值范围
create_heatmap(people, categories, value_range=(0, 1))# 自定义图形大小
create_heatmap(people, categories, figsize=(15, 8))# 调整小数位数
create_heatmap(people, categories, decimal_places=2)
"""
柱状图
import matplotlib. pyplot as plt
import numpy as npdef set_chinese_font ( ) : """设置中文字体""" plt. rcParams[ 'font.sans-serif' ] = [ 'SimHei' ] plt. rcParams[ 'axes.unicode_minus' ] = False def simple_bar ( categories, values, title= '简单柱状图' , xlabel= '类别' , ylabel= '数值' , color= 'skyblue' , figsize= ( 10 , 6 ) , show_values= True ) : """绘制简单柱状图参数:categories: 类别名称列表values: 数值列表title: 图表标题xlabel: x轴标签ylabel: y轴标签color: 柱子颜色figsize: 图形大小show_values: 是否显示数值标签""" set_chinese_font( ) plt. figure( figsize= figsize) bars = plt. bar( categories, values, color= color) plt. title( title) plt. xlabel( xlabel) plt. ylabel( ylabel) if show_values: for bar in bars: height = bar. get_height( ) plt. text( bar. get_x( ) + bar. get_width( ) / 2 . , height, f' { height: .1f } ' , ha= 'center' , va= 'bottom' ) plt. tight_layout( ) return plt. gcf( ) def grouped_bar ( categories, data_dict, title= '分组柱状图' , xlabel= '类别' , ylabel= '数值' , figsize= ( 12 , 6 ) , show_values= True ) : """绘制分组柱状图参数:categories: 类别名称列表data_dict: 数据字典,格式为 {'组名': [数值列表]}title: 图表标题xlabel: x轴标签ylabel: y轴标签figsize: 图形大小show_values: 是否显示数值标签""" set_chinese_font( ) plt. figure( figsize= figsize) n_groups = len ( categories) n_bars = len ( data_dict) bar_width = 0.8 / n_barscolors = plt. cm. Paired( np. linspace( 0 , 1 , n_bars) ) for idx, ( label, values) in enumerate ( data_dict. items( ) ) : x = np. arange( n_groups) + idx * bar_widthbars = plt. bar( x, values, bar_width, label= label, color= colors[ idx] ) if show_values: for bar in bars: height = bar. get_height( ) plt. text( bar. get_x( ) + bar. get_width( ) / 2 . , height, f' { height: .1f } ' , ha= 'center' , va= 'bottom' ) plt. title( title) plt. xlabel( xlabel) plt. ylabel( ylabel) plt. xticks( np. arange( n_groups) + ( bar_width * ( n_bars- 1 ) ) / 2 , categories) plt. legend( ) plt. tight_layout( ) return plt. gcf( ) def stacked_bar ( categories, data_dict, title= '堆叠柱状图' , xlabel= '类别' , ylabel= '数值' , figsize= ( 10 , 6 ) , show_values= True ) : """绘制堆叠柱状图参数:categories: 类别名称列表data_dict: 数据字典,格式为 {'组名': [数值列表]}title: 图表标题xlabel: x轴标签ylabel: y轴标签figsize: 图形大小show_values: 是否显示数值标签""" set_chinese_font( ) plt. figure( figsize= figsize) bottom = np. zeros( len ( categories) ) colors = plt. cm. Paired( np. linspace( 0 , 1 , len ( data_dict) ) ) for idx, ( label, values) in enumerate ( data_dict. items( ) ) : plt. bar( categories, values, bottom= bottom, label= label, color= colors[ idx] ) if show_values: for i, v in enumerate ( values) : plt. text( i, bottom[ i] + v/ 2 , f' { v: .1f } ' , ha= 'center' , va= 'center' ) bottom += valuesplt. title( title) plt. xlabel( xlabel) plt. ylabel( ylabel) plt. legend( ) plt. tight_layout( ) return plt. gcf( ) def horizontal_bar ( categories, values, title= '横向柱状图' , xlabel= '数值' , ylabel= '类别' , color= 'skyblue' , figsize= ( 10 , 6 ) , show_values= True ) : """绘制横向柱状图参数:categories: 类别名称列表values: 数值列表title: 图表标题xlabel: x轴标签ylabel: y轴标签color: 柱子颜色figsize: 图形大小show_values: 是否显示数值标签""" set_chinese_font( ) plt. figure( figsize= figsize) y_pos = np. arange( len ( categories) ) bars = plt. barh( y_pos, values, color= color) plt. yticks( y_pos, categories) plt. title( title) plt. xlabel( xlabel) plt. ylabel( ylabel) if show_values: for bar in bars: width = bar. get_width( ) plt. text( width, bar. get_y( ) + bar. get_height( ) / 2 . , f' { width: .1f } ' , ha= 'left' , va= 'center' ) plt. tight_layout( ) return plt. gcf( )
if __name__ == "__main__" : categories = [ '产品A' , '产品B' , '产品C' , '产品D' ] values = np. random. randint( 50 , 100 , size= len ( categories) ) group_data = { '2021年' : np. random. randint( 50 , 100 , size= len ( categories) ) , '2022年' : np. random. randint( 50 , 100 , size= len ( categories) ) , '2023年' : np. random. randint( 50 , 100 , size= len ( categories) ) } simple_bar( categories, values, title= '销售数据' ) plt. show( ) grouped_bar( categories, group_data, title= '年度销售对比' ) plt. show( ) stacked_bar( categories, group_data, title= '年度销售累计' ) plt. show( ) horizontal_bar( categories, values, title= '销售数据(横向)' ) plt. show( )