I am trying to fine-tune a base model with around 5-8 billion parameters. My dataset results from combining the Dolly-15K and alpaca-cleaned datasets. I want to perform a forward pass to get the logits of the base model's output. However, when I run the forward pass over batches of inputs built with a DataLoader, I keep getting the same error. It pops up even when I change the model.
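
For context, combined.json below was produced by concatenating the two sources into a single "prompt" column. A minimal sketch of that step (the Hub dataset IDs and the exact prompt formatting are assumptions and simplified here, since they are not relevant to the error):

# Sketch of how combined.json was built (simplified)
from datasets import load_dataset, concatenate_datasets

dolly = load_dataset("databricks/databricks-dolly-15k", split="train")
alpaca = load_dataset("yahma/alpaca-cleaned", split="train")

# Flatten each record into a single "prompt" text column
def dolly_to_prompt(example):
    return {"prompt": f"{example['instruction']}\n{example['context']}\n{example['response']}"}

def alpaca_to_prompt(example):
    return {"prompt": f"{example['instruction']}\n{example['input']}\n{example['output']}"}

combined = concatenate_datasets([
    dolly.map(dolly_to_prompt, remove_columns=dolly.column_names),
    alpaca.map(alpaca_to_prompt, remove_columns=alpaca.column_names),
])
combined.to_json("../data/combined.json")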

The code is the following:

# import necessary libraries
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments
)
from bitsandbytes.nn import Linear4bit

from trl import SFTTrainer
from evaluate import load
import time
import json
from tqdm import tqdm

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"

# The dataset has been already saved locally
dataset = load_dataset("json", data_files="../data/combined.json")

# Load model
model_name = "Qwen/Qwen2.5-7B"
# Load the tokenizer for the model
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    add_eos_token=True,      # Add end-of-sequence token to the tokenizer
    use_fast=True,           # Use the fast tokenizer implementation
    padding_side='left',      # Pad sequences on the left side
)

# Commenting out this line produces the same error
tokenizer.pad_token = tokenizer.eos_token  # Set padding token to EOS token

# Quantization configuration using bitsandbytes library
compute_dtype = getattr(torch, "bfloat16")  # Set computation data type to bfloat16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                    # Enable loading the model in 4-bit precision
    bnb_4bit_quant_type="nf4",            # Specify quantization type as Normal Float 4
    bnb_4bit_compute_dtype=compute_dtype, # Set computation data type
    bnb_4bit_use_double_quant=True,       # Use double quantization for better accuracy
)

# Load the pre-trained model with the specified quantization configuration
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,  # Apply quantization configuration
    device_map="cpu"                 # Keep every layer on the CPU (to reproduce the error here)
)

def format_data(example):
    return tokenizer(example["prompt"], padding="max_length", truncation=True, max_length=512, return_attention_mask=True)

tokenized_dataset = dataset.map(format_data, batched=True)

tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])

inputs_dl = DataLoader(tokenized_dataset["train"], batch_size=8)

# Set to evaluation mode
model.eval()

# Initialize a progress bar
total_batches = len(inputs_dl)
progress_bar = tqdm(inputs_dl, total=total_batches, desc="Processing Batches")

# Disable gradient calculations for inference
with torch.no_grad():
    for batch in progress_bar:
        # Keep the input tensors on the CPU (same device as the model)
        input_ids = batch["input_ids"].to("cpu")
        attention_mask = batch["attention_mask"].to("cpu")
        
        # Perform the forward pass
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

No further code is shown because the error is raised on this last line, before a single iteration completes.

Note: both the model and the inputs are on the CPU so that I can see the full error; with CUDA I could not see it.
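
For completeness: since the assertion fires inside a Linear4bit layer regardless of the input, a single hand-tokenized batch (sketch below, reusing the tokenizer and model loaded above and bypassing the DataLoader) should hit the same error:

# Minimal single-batch check, bypassing the DataLoader
sample = tokenizer(["Explain what a forward pass is."], return_tensors="pt", padding=True)
with torch.no_grad():
    out = model(input_ids=sample["input_ids"], attention_mask=sample["attention_mask"])
print(out.logits.shape)  # not reached: the AssertionError fires inside q_proj (Linear4bit)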

This is the error I am getting:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[8], line 25
     22 attention_mask = batch["attention_mask"].to("cpu")
     24 # Perform the forward pass
---> 25 outputs = model(input_ids=input_ids, attention_mask=attention_mask)
     27 # Extract logits
     28 logits = outputs.logits  # Shape: (batch_size, seq_length, vocab_size)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.local/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/.local/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:1167, in Qwen2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep)
   1164 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1166 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1167 outputs = self.model(
   1168     input_ids=input_ids,
   1169     attention_mask=attention_mask,
   1170     position_ids=position_ids,
   1171     past_key_values=past_key_values,
   1172     inputs_embeds=inputs_embeds,
   1173     use_cache=use_cache,
   1174     output_attentions=output_attentions,
   1175     output_hidden_states=output_hidden_states,
   1176     return_dict=return_dict,
   1177     cache_position=cache_position,
   1178 )
   1180 hidden_states = outputs[0]
   1181 if labels is None and not is_torchdynamo_compiling():

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.local/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/.local/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:976, in Qwen2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    964     layer_outputs = self._gradient_checkpointing_func(
    965         decoder_layer.__call__,
    966         hidden_states,
   (...)
    973         position_embeddings,
    974     )
    975 else:
--> 976     layer_outputs = decoder_layer(
    977         hidden_states,
    978         attention_mask=causal_mask,
    979         position_ids=position_ids,
    980         past_key_value=past_key_values,
    981         output_attentions=output_attentions,
    982         use_cache=use_cache,
    983         cache_position=cache_position,
    984         position_embeddings=position_embeddings,
    985     )
    987 hidden_states = layer_outputs[0]
    989 if use_cache:

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.local/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/.local/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:702, in Qwen2DecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)
    699 hidden_states = self.input_layernorm(hidden_states)
    701 # Self Attention
--> 702 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    703     hidden_states=hidden_states,
    704     attention_mask=attention_mask,
    705     position_ids=position_ids,
    706     past_key_value=past_key_value,
    707     output_attentions=output_attentions,
    708     use_cache=use_cache,
    709     cache_position=cache_position,
    710     position_embeddings=position_embeddings,
    711 )
    712 hidden_states = residual + hidden_states
    714 # Fully Connected

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.local/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/.local/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:580, in Qwen2SdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings)
    569     return super().forward(
    570         hidden_states=hidden_states,
    571         attention_mask=attention_mask,
   (...)
    575         use_cache=use_cache,
    576     )
    578 bsz, q_len, _ = hidden_states.size()
--> 580 query_states = self.q_proj(hidden_states)
    581 key_states = self.k_proj(hidden_states)
    582 value_states = self.v_proj(hidden_states)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.local/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/.local/lib/python3.10/site-packages/bitsandbytes/nn/modules.py:469, in Linear4bit.forward(self, x)
    468 def forward(self, x: torch.Tensor):
--> 469     fix_4bit_weight_quant_state_from_module(self)
    471     # weights are cast automatically as Int8Params, but the bias has to be cast manually
    472     if self.bias is not None and self.bias.dtype != x.dtype:

File ~/.local/lib/python3.10/site-packages/bitsandbytes/nn/modules.py:361, in fix_4bit_weight_quant_state_from_module(module)
    355     warnings.warn(
    356         "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
    357     )
    359 # the quant state got lost when the parameter got converted. This happens for example for fsdp
    360 # since we registered the module, we can recover the state here
--> 361 assert module.weight.shape[1] == 1
    362 if not isinstance(module.weight, Params4bit):
    363     module.weight = Params4bit(module.weight, quant_storage=module.quant_storage, bnb_quantized=True)

AssertionError:
