// 构建卷积网络 — Build the convolutional network
# include <torch/torch.h>
# include <torch/script.h>
# include <iostream> using std:: cout; using std:: endl; class LinearBnReluImpl : public torch:: nn:: Module
{
private : torch:: nn:: Linear ln{ nullptr } ; torch:: nn:: BatchNorm1d bn{ nullptr } ; public : LinearBnReluImpl ( int input_features, int out_features) ; torch:: Tensor forward ( torch:: Tensor x) ;
} ;
TORCH_MODULE ( LinearBnRelu) ; inline torch:: nn:: Conv2dOptions conv_options ( int64_t in_planes, int64_t out_planes, int64_t kernel_size, int64_t stride = 1 , int64_t padding = 0 , bool with_bias = false
)
{ torch:: nn:: Conv2dOptions conv_options = torch:: nn:: Conv2dOptions ( in_planes, out_planes, kernel_size) ; conv_options. stride ( stride) ; conv_options. padding ( padding) ; conv_options. bias ( with_bias) ; return conv_options;
} class ConvReluBnImpl : public torch:: nn:: Module
{
private : torch:: nn:: Conv2d conv{ nullptr } ; torch:: nn:: BatchNorm2d bn{ nullptr } ; public : ConvReluBnImpl ( int input_channel, int output_channel, int kernel_size, int stride, int padding= 1 ) ; torch:: Tensor forward ( torch:: Tensor x) ;
} ;
TORCH_MODULE ( ConvReluBn) ; class MLP : public torch:: nn:: Module
{
private : int mid_features[ 3 ] = { 32 , 64 , 128 } ; LinearBnRelu ln1{ nullptr } ; LinearBnRelu ln2{ nullptr } ; LinearBnRelu ln3{ nullptr } ; torch:: nn:: Linear out_ln{ nullptr } ; public : MLP ( int in_features, int out_features) ; torch:: Tensor forward ( torch:: Tensor x) ;
} ; class plainCNN : public torch:: nn:: Module
{
private : int mid_channels[ 3 ] { 32 , 64 , 128 } ; ConvReluBn conv1{ nullptr } ; ConvReluBn down1{ nullptr } ; ConvReluBn conv2{ nullptr } ; ConvReluBn down2{ nullptr } ; ConvReluBn conv3{ nullptr } ; ConvReluBn down3{ nullptr } ; torch:: nn:: Conv2d out_conv{ nullptr } ; public : plainCNN ( int in_channels, int out_channels) ; torch:: Tensor forward ( torch:: Tensor x) ;
} ; int main ( )
{ plainCNN c ( 3 , 2 ) ; auto x = torch:: rand ( { 1 , 3 , 224 , 224 } , torch:: kFloat) ; auto a = c. forward ( x) ; cout << "[in Main]: " << a. sizes ( ) << endl; return 0 ;
} LinearBnReluImpl :: LinearBnReluImpl ( int input_features, int out_features)
{ ln = register_module ( "ln" , torch:: nn:: Linear ( torch:: nn:: LinearOptions ( input_features, out_features) ) ) ; bn = register_module ( "bn" , torch:: nn:: BatchNorm1d ( out_features) ) ;
} torch:: Tensor LinearBnReluImpl :: forward ( torch:: Tensor x)
{ x = torch:: relu ( ln-> forward ( x) ) ; x = bn ( x) ; return x;
} ConvReluBnImpl :: ConvReluBnImpl ( int input_channel, int output_channel, int kernel_size, int stride, int padding)
{ conv = register_module ( "conv" , torch:: nn:: Conv2d ( conv_options ( input_channel, output_channel, kernel_size, stride, padding) ) ) ; bn = register_module ( "bn" , torch:: nn:: BatchNorm2d ( output_channel) ) ;
} torch:: Tensor ConvReluBnImpl :: forward ( torch:: Tensor x)
{ x = torch:: relu ( conv-> forward ( x) ) ; x = bn ( x) ; return x;
} MLP :: MLP ( int in_features, int out_features)
{ ln1 = LinearBnRelu ( in_features, mid_features[ 0 ] ) ; ln2 = LinearBnRelu ( mid_features[ 0 ] , mid_features[ 1 ] ) ; ln3 = LinearBnRelu ( mid_features[ 1 ] , mid_features[ 2 ] ) ; out_ln = torch:: nn:: Linear ( mid_features[ 2 ] , out_features) ; ln1 = register_module ( "ln1" , ln1) ; ln2 = register_module ( "ln2" , ln2) ; ln3 = register_module ( "ln3" , ln3) ; out_ln = register_module ( "out_ln" , out_ln) ;
} torch:: Tensor MLP :: forward ( torch:: Tensor x)
{ x = ln1-> forward ( x) ; x = ln2-> forward ( x) ; x = ln3-> forward ( x) ; x = out_ln-> forward ( x) ; return x;
} plainCNN:: plainCNN ( int in_channels, int out_channels)
{ conv1 = ConvReluBn ( in_channels, mid_channels[ 0 ] , 3 , 1 ) ; down1 = ConvReluBn ( mid_channels[ 0 ] , mid_channels[ 0 ] , 3 , 2 ) ; conv2 = ConvReluBn ( mid_channels[ 0 ] , mid_channels[ 1 ] , 3 , 1 ) ; down2 = ConvReluBn ( mid_channels[ 1 ] , mid_channels[ 1 ] , 3 , 2 ) ; conv3 = ConvReluBn ( mid_channels[ 1 ] , mid_channels[ 2 ] , 3 , 1 ) ; down3 = ConvReluBn ( mid_channels[ 2 ] , mid_channels[ 2 ] , 3 , 2 ) ; out_conv = torch:: nn:: Conv2d ( conv_options ( mid_channels[ 2 ] , out_channels, 3 ) ) ; conv1 = register_module ( "conv1" , conv1) ; down1 = register_module ( "down1" , down1) ; conv2 = register_module ( "conv2" , conv2) ; down2 = register_module ( "down2" , down2) ; conv3 = register_module ( "conv3" , conv3) ; down3 = register_module ( "down3" , down3) ; out_conv = register_module ( "out_conv" , out_conv) ;
} torch:: Tensor plainCNN:: forward ( torch:: Tensor x)
{ x = conv1-> forward ( x) ; cout << x. sizes ( ) << endl; x = down1-> forward ( x) ; cout << x. sizes ( ) << endl; x = conv2-> forward ( x) ; cout << x. sizes ( ) << endl; x = down2-> forward ( x) ; cout << x. sizes ( ) << endl; x = conv3-> forward ( x) ; cout << x. sizes ( ) << endl; x = down3-> forward ( x) ; cout << x. sizes ( ) << endl; x = out_conv-> forward ( x) ; cout << x. sizes ( ) << endl; return x;
}
// 结果 — Result