admin管理员组

文章数量:1277290

The following program crashes upon execution:

from datasets import IterableDataset, Dataset
from trl import GRPOConfig, GRPOTrainer

prompts = ["Hi", "Hello"]


def data_generator():
    """Yield ``{"prompt": text}`` records forever, cycling through ``prompts``.

    The generator never terminates on its own: once every entry of the
    module-level ``prompts`` list has been emitted, it starts over from the
    first entry.
    """
    while True:
        yield from ({"prompt": text} for text in prompts)
# Build the training dataset from the endless generator above, so records are
# produced lazily instead of being stored up front.
dataset = IterableDataset.from_generator(data_generator)


# The generator never ends, so presumably max_steps is what bounds the run
# (TODO confirm against the GRPOConfig / Trainer documentation).
training_args = GRPOConfig(
    output_dir= "tmp",
    max_steps = 1000,
)

trainer = GRPOTrainer(
    model="facebook/opt-350m",
    # Dummy reward: ignores its inputs and always returns a list of eight 1s.
    # NOTE(review): the hard-coded length 8 presumably matches the number of
    # completions scored per call under the default config -- confirm.
    reward_funcs=lambda prompts,completions, **kwargs: [1]*8,
    train_dataset=dataset,
    args=training_args,
)

# With the IterableDataset above, this call is where the reported
# "TypeError: Can only concatenate tensors but got <class 'str'>" is raised.
trainer.train()

It produces the following traceback:

Traceback (most recent call last):
  File "/home/pietro/Documents/Code/CS234/starter_code/trl_testing.py", line 24, in <module>
    trainer.train()
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/transformers/trainer.py", line 2241, in train
    return inner_training_loop(
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/transformers/trainer.py", line 2500, in _inner_training_loop
    batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/transformers/trainer.py", line 5180, in get_batch_samples
    batch_samples += [next(epoch_iterator)]
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/data_loader.py", line 856, in __iter__
    next_batch, next_batch_info = self._fetch_batches(main_iterator)
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/data_loader.py", line 812, in _fetch_batches
    batch = concatenate(batches, dim=0)
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/utils/operations.py", line 615, in concatenate
    return honor_type(data[0], (concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0]))))
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/utils/operations.py", line 81, in honor_type
    return type(obj)(generator)
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/utils/operations.py", line 615, in <genexpr>
    return honor_type(data[0], (concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0]))))
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/utils/operations.py", line 617, in concatenate
    return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()})
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/utils/operations.py", line 617, in <dictcomp>
    return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()})
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/utils/operations.py", line 619, in concatenate
    raise TypeError(f"Can only concatenate tensors but got {type(data[0])}")
TypeError: Can only concatenate tensors but got <class 'str'>

However, replacing the IterableDataset with an analogous Dataset, as done below, fixes the issue:

from datasets import IterableDataset, Dataset
from trl import GRPOConfig, GRPOTrainer

prompts = ["Hi", "Hello"]
# Same two prompts as in the failing script, but stored in a regular Dataset
# built from an in-memory dict rather than streamed from a generator.
dataset = Dataset.from_dict({"prompt" : prompts})

# Configuration is identical to the failing script.
training_args = GRPOConfig(
    output_dir= "tmp",
    max_steps = 1000,
)

trainer = GRPOTrainer(
    model="facebook/opt-350m",
    # Dummy reward: ignores its inputs and always returns a list of eight 1s.
    # NOTE(review): the hard-coded length 8 presumably matches the number of
    # completions scored per call under the default config -- confirm.
    reward_funcs=lambda prompts,completions, **kwargs: [1]*8,
    train_dataset=dataset,
    args=training_args,
)

# With the Dataset above, training is reported to run without the TypeError.
trainer.train()

This has been reproduced on two very different systems, so it is unlikely that the environment is the cause. Am I missing something?

本文标签: python | IterableDataset not supported on GRPOTrainer | Stack Overflow