PrefixEncoder
class PrefixEncoder(torch.nn.Module):
    """
    The torch.nn model to encode the prefix
    Input shape: (batch-size, prefix-length)
    Output shape: (batch-size, prefix-length, 2*layers*hidden)
    """

    def __init__(self, config: ChatGLMConfig):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        if self.prefix_projection:
            # Use a two-layer MLP to encode the prefix
            kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
            self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(kv_size, config.hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.hidden_size, kv_size)
            )
        else:
            self.embedding = torch.nn.Embedding(
                config.pre_seq_len,
                config.num_layers * config.kv_channels * config.multi_query_group_num * 2
            )

    def forward(self, prefix: torch.Tensor):
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)
            past_key_values = self.trans(prefix_tokens)
        else:
            past_key_values = self.embedding(prefix)
        return past_key_values
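To make the shapes concrete, here is a minimal sketch (not part of the original file) that instantiates PrefixEncoder with a duck-typed stand-in for ChatGLMConfig. The attribute values are assumptions chosen to mirror a ChatGLM2-6B-style config; the real values come from the model's own ChatGLMConfig, and the type annotation on __init__ is only a hint, so the dummy object works.

import torch
from modeling_chatglm import PrefixEncoder  # hypothetical local import of the class above

class DummyConfig:
    # Illustrative values only; take the real ones from ChatGLMConfig.
    prefix_projection = False
    pre_seq_len = 16
    num_layers = 28
    kv_channels = 128
    multi_query_group_num = 2
    hidden_size = 4096

config = DummyConfig()
encoder = PrefixEncoder(config)

prefix = torch.arange(config.pre_seq_len).unsqueeze(0)  # (batch=1, prefix_len=16)
past_key_values = encoder(prefix)
# 2 * num_layers * multi_query_group_num * kv_channels = 2 * 28 * 2 * 128 = 14336
print(past_key_values.shape)  # torch.Size([1, 16, 14336])

With prefix_projection=True the same call would route through the two-layer MLP (embedding -> Linear -> Tanh -> Linear) and return the same output shape.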
ChatGLMPreTrainedModel
class ChatGLMPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    is_parallelizable = False
    supports_gradient_checkpointing = True
    config_class = ChatGLMConfig
    base_model_prefix = "transformer"
    _no_split_modules = ["GLMBlock"]

    def _init_weights(self, module: nn.Module):
        """Initialize the weights."""
        return

    def get_masks(self, input_ids, past_key_values, padding_mask=None):
        batch_size, seq_length = input_ids.shape
        full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
        full_attention_mask.tril_()
        past_length = 0
        if past_key_values:
            past_length = past_key_values[0][0].shape[0]
        if past_length:
            full_attention_mask = torch.cat(
                (torch.ones(batch_size, seq_length, past_length, device=input_ids.device), full_attention_mask),
                dim=-1
            )
        if padding_mask is not None:
            full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
        if not past_length and padding_mask is not None:
            full_attention_mask -= padding_mask.unsqueeze(-1) - 1
        full_attention_mask = (full_attention_mask < 0.5).bool()
        full_attention_mask.unsqueeze_(1)
        return full_attention_mask

    def get_position_ids(self, input_ids, device):
        batch_size, seq_length = input_ids.shape
        position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
        return position_ids

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, GLMTransformer):
            module.gradient_checkpointing = value
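The arithmetic inside get_masks is compact, so a small standalone sketch (written for this walkthrough, not taken from the file) may help. It replays the same steps on a toy 1x4 batch whose last token is padding and prints the resulting boolean mask, where True marks a position that the causal-plus-padding mask disallows.

import torch

batch_size, seq_length = 1, 4
padding_mask = torch.tensor([[1, 1, 1, 0]])  # 1 = real token, 0 = padding

full = torch.ones(batch_size, seq_length, seq_length)
full.tril_()                              # causal lower-triangular part (no past cache here)
full = full * padding_mask.unsqueeze(1)   # hide the padded key position from every query
full -= padding_mask.unsqueeze(-1) - 1    # padded query rows get +1 so nothing is masked for them
mask = (full < 0.5).bool().unsqueeze(1)   # (batch, 1, seq, seq); True = disallowed

print(mask[0, 0].int())
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.int32)

Leaving the padded query row fully unmasked avoids an all-masked attention row; its output is ignored anyway because the position is padding.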
ChatGLMForConditionalGeneration.stream_generate()
@torch.inference_mode()
def stream_generate(
        self,
        input_ids,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        return_past_key_values=False,
        **kwargs,
):
    batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]

    if generation_config is None:
        generation_config = self.generation_config
    generation_config = copy.deepcopy(generation_config)
    model_kwargs = generation_config.update(**kwargs)
    model_kwargs["use_cache"] = generation_config.use_cache
    bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id

    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None

    has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
    if has_default_max_length and generation_config.max_new_tokens is None:
        warnings.warn(
            f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
            "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
            " recommend using `max_new_tokens` to control the maximum length of the generation.",
            UserWarning,
        )
    elif generation_config.max_new_tokens is not None:
        generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
        if not has_default_max_length:
            logger.warn(
                f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                "Please refer to the documentation for more information. "
                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
                UserWarning,
            )

    if input_ids_seq_length >= generation_config.max_length:
        input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
        logger.warning(
            f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
            f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
            " increasing `max_new_tokens`."
        )

    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

    logits_processor = self._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_seq_length,
        encoder_input_ids=input_ids,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
    )

    stopping_criteria = self._get_stopping_criteria(
        generation_config=generation_config, stopping_criteria=stopping_criteria
    )
    logits_warper = self._get_logits_warper(generation_config)

    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    scores = None
    while True:
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        # forward pass to get next-token logits
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )

        next_token_logits = outputs.logits[:, -1, :]

        # pre-process the distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # sample (or take the argmax when do_sample is off)
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        if generation_config.do_sample:
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(probs, dim=-1)

        # update generated ids and model kwargs for the next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        unfinished_sequences = unfinished_sequences.mul(
            next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
        )
        if return_past_key_values:
            yield input_ids, outputs.past_key_values
        else:
            yield input_ids
        # stop when every sequence has hit EOS or a stopping criterion fires
        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
            break
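For context on how this generator is typically consumed, here is an assumed usage sketch (not code from the file): each yield is the full token sequence so far, so the caller decodes it and slices off the prompt. The checkpoint name, sampling parameters, and bare-bones prompt handling below are illustrative only; the official repository wraps stream_generate in a higher-level stream_chat helper that also builds the chat prompt.

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).half().cuda().eval()

inputs = tokenizer(["Hello"], return_tensors="pt").to(model.device)
prompt_len = inputs["input_ids"].shape[-1]

# Extra kwargs are folded into the GenerationConfig by generation_config.update(**kwargs).
for output_ids in model.stream_generate(**inputs, max_new_tokens=128,
                                        do_sample=True, top_p=0.8, temperature=0.8):
    # Each yield is the whole sequence generated so far; drop the prompt tokens.
    new_text = tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)
    print(new_text, end="\r", flush=True)
print()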