I'm trying to fine-tune a model using SFTTrainer from trl, but I keep running into TypeErrors about unexpected keyword arguments.

from transformers import TrainingArguments
from trl import SFTTrainer

output_dir = "tinyllama_instruct"
training_arguments = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,
    save_strategy="epoch",
    evaluation_strategy="epoch",
    logging_steps=25,
    learning_rate=2e-5,
    max_grad_norm=1.0,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    fp16=True,
    report_to=["tensorboard", "wandb"],
    num_train_epochs=1,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=tokenizer,
    packing=True,  # Causes TypeError
    dataset_text_field="content",  # Causes TypeError if packing is removed
    max_seq_length=2048,  # Causes TypeError if dataset_text_field is removed
)
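Which keyword arguments SFTTrainer accepts depends on the installed trl release, so it helps to state the exact package versions next to the code above. A minimal sketch using only the standard library's `importlib.metadata`; `installed_versions` is a hypothetical helper name, and the package list is an assumption about what matters here.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_versions(packages: list[str]) -> dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    found: dict[str, Optional[str]] = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


# Print the versions in a form that can be included in a bug report.
for pkg, ver in installed_versions(["trl", "transformers", "datasets", "torch"]).items():
    print(f"{pkg}: {ver or 'not installed'}")
```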

The notebook can be found here: .ipynb

Errors Encountered:

  • TypeError: SFTTrainer.__init__() got an unexpected keyword argument 'packing'
  • Removing packing=True results in:
    TypeError: SFTTrainer.__init__() got an unexpected keyword argument 'dataset_text_field'
  • Removing dataset_text_field="content" results in:
    TypeError: SFTTrainer.__init__() got an unexpected keyword argument 'max_seq_length'
  • Finally, when I remove all these arguments, I get a KeyError: 'text' while tokenizing.
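The one-at-a-time pattern above is how Python itself reports these errors: when a function has no `**kwargs` catch-all, the call fails on the first unrecognized keyword, so each removal only reveals the next rejected argument. A sketch that lists every rejected keyword at once by inspecting the constructor's signature with the standard `inspect` module; `rejected_kwargs` and the `NewStyleTrainer` stand-in class are hypothetical, and in practice you would pass `SFTTrainer` itself.

```python
import inspect
from typing import Any, Callable, Iterable


def rejected_kwargs(fn: Callable[..., Any], names: Iterable[str]) -> list[str]:
    """Return the keyword names that `fn` would reject with a TypeError.

    If `fn` accepts **kwargs, nothing is rejected at the signature level
    (its body may still complain), so an empty list is returned.
    """
    params = inspect.signature(fn).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return []
    # Positional-only parameters cannot be passed by keyword either.
    accepted = {
        name
        for name, p in params.items()
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                      inspect.Parameter.KEYWORD_ONLY)
    }
    return [name for name in names if name not in accepted]


class NewStyleTrainer:
    """Hypothetical stand-in for a trainer whose __init__ does not take
    packing / dataset_text_field / max_seq_length."""

    def __init__(self, model=None, args=None, train_dataset=None, eval_dataset=None):
        self.model = model


requested = ["model", "args", "packing", "dataset_text_field", "max_seq_length"]
print(rejected_kwargs(NewStyleTrainer, requested))
# → ['packing', 'dataset_text_field', 'max_seq_length']

try:
    NewStyleTrainer(model=None, packing=True, dataset_text_field="content")
except TypeError as err:
    print(err)  # only the first offending keyword, 'packing', is reported
```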

What I’ve Tried:

  • Removing the problematic arguments one by one, but each time a new issue arises.
  • Checking the latest trl documentation, where packing, dataset_text_field, and max_seq_length no longer appear as SFTTrainer arguments.
  • Verifying my dataset structure.
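The final KeyError: 'text' is consistent with the trainer falling back to a default text column named "text" while this dataset stores its text under "content" (an assumption based on the error message). A minimal pure-Python sketch of the structural check and the rename, modeling the dataset as a list of dicts; `check_text_field` and `rename_field` are hypothetical helpers, and the toy rows are invented. With a Hugging Face `datasets.Dataset`, the equivalent rename is `dataset.rename_column("content", "text")`.

```python
from typing import Any


def check_text_field(rows: list[dict[str, Any]], field: str) -> list[int]:
    """Return the indices of rows that lack `field` or whose value is not a string."""
    return [i for i, row in enumerate(rows) if not isinstance(row.get(field), str)]


def rename_field(rows: list[dict[str, Any]], old: str, new: str) -> list[dict[str, Any]]:
    """Return copies of the rows with key `old` renamed to `new` (other keys kept)."""
    renamed = []
    for row in rows:
        row = dict(row)  # copy so the input rows are not mutated
        if old in row:
            row[new] = row.pop(old)
        renamed.append(row)
    return renamed


# Toy rows shaped like the question's dataset: the text lives under "content".
rows = [
    {"content": "### Instruction: say hi\n### Response: hi"},
    {"content": "### Instruction: count to 2\n### Response: 1 2"},
]

print(check_text_field(rows, "text"))     # → [0, 1]  (every row lacks "text")
print(check_text_field(rows, "content"))  # → []

fixed = rename_field(rows, "content", "text")
print(check_text_field(fixed, "text"))    # → []
```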

Questions:

  • Has the SFTTrainer API changed recently, and are these arguments deprecated?
  • How should I correctly pass max_seq_length and specify the text field in my dataset?
  • Is packing handled differently now?
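For orientation on what packing means: the classic approach concatenates tokenized examples (often with an end-of-sequence token between them) and cuts the resulting stream into fixed-length blocks of max_seq_length tokens, so no padding is needed. The sketch below illustrates that idea on plain integer token IDs; `pack_sequences` and the EOS id are hypothetical, and trl's own packing implementation may differ between versions (for example in how it treats example boundaries or the trailing partial block).

```python
from typing import Iterable

EOS_ID = 2  # hypothetical end-of-sequence token id


def pack_sequences(examples: Iterable[list[int]], max_seq_length: int,
                   eos_id: int = EOS_ID) -> list[list[int]]:
    """Concatenate tokenized examples (each followed by eos_id) and split the
    stream into blocks of exactly max_seq_length tokens.

    Trailing tokens that do not fill a whole block are dropped, which is one
    common choice; other implementations pad or keep the final partial block.
    """
    if max_seq_length <= 0:
        raise ValueError("max_seq_length must be positive")
    stream: list[int] = []
    for ids in examples:
        stream.extend(ids)
        stream.append(eos_id)
    n_full = len(stream) // max_seq_length
    return [stream[i * max_seq_length:(i + 1) * max_seq_length] for i in range(n_full)]


examples = [[11, 12, 13], [21, 22], [31, 32, 33, 34]]
print(pack_sequences(examples, max_seq_length=4))
# → [[11, 12, 13, 2], [21, 22, 2, 31], [32, 33, 34, 2]]
```

Note that the middle block spans two examples; that is the intended effect of packing, and it is why packed training usually relies on the separator token to mark where one example ends.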

Any guidance would be greatly appreciated!