The following program crashes on execution:
from datasets import IterableDataset, Dataset
from trl import GRPOConfig, GRPOTrainer

prompts = ["Hi", "Hello"]

def data_generator():
    while True:
        for s in prompts:
            yield {"prompt": s}

dataset = IterableDataset.from_generator(data_generator)
training_args = GRPOConfig(
    output_dir="tmp",
    max_steps=1000,
)
trainer = GRPOTrainer(
    model="facebook/opt-350m",
    reward_funcs=lambda prompts, completions, **kwargs: [1] * 8,
    train_dataset=dataset,
    args=training_args,
)
trainer.train()
It produces the following traceback:
Traceback (most recent call last):
  File "/home/pietro/Documents/Code/CS234/starter_code/trl_testing.py", line 24, in <module>
    trainer.train()
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/transformers/trainer.py", line 2241, in train
    return inner_training_loop(
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/transformers/trainer.py", line 2500, in _inner_training_loop
    batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/transformers/trainer.py", line 5180, in get_batch_samples
    batch_samples += [next(epoch_iterator)]
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/data_loader.py", line 856, in __iter__
    next_batch, next_batch_info = self._fetch_batches(main_iterator)
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/data_loader.py", line 812, in _fetch_batches
    batch = concatenate(batches, dim=0)
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/utils/operations.py", line 615, in concatenate
    return honor_type(data[0], (concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0]))))
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/utils/operations.py", line 81, in honor_type
    return type(obj)(generator)
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/utils/operations.py", line 615, in <genexpr>
    return honor_type(data[0], (concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0]))))
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/utils/operations.py", line 617, in concatenate
    return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()})
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/utils/operations.py", line 617, in <dictcomp>
    return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()})
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/utils/operations.py", line 619, in concatenate
    raise TypeError(f"Can only concatenate tensors but got {type(data[0])}")
TypeError: Can only concatenate tensors but got <class 'str'>
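The last frames of the trace suggest how the error arises: the concatenation step recurses through a list (line 615), then through a dict (line 617), and finally reaches a plain string, which it rejects (line 619). The sketch below is a simplified, pure-Python stand-in for that recursion, not accelerate's actual code. The name `concatenate_sketch` is hypothetical, the tensor branch is left out on purpose, and the shape of the batches (lists of `{"prompt": str}` dicts) is an assumption inferred from the trace.

```python
def concatenate_sketch(data):
    """Hypothetical, simplified model of the recursion visible in the traceback."""
    first = data[0]
    if isinstance(first, (list, tuple)):
        # Recurse position by position (compare the frame at operations.py line 615).
        return type(first)(
            concatenate_sketch([d[i] for d in data]) for i in range(len(first))
        )
    if isinstance(first, dict):
        # Recurse key by key (compare the frame at line 617).
        return {k: concatenate_sketch([d[k] for d in data]) for k in first}
    # The real function would concatenate tensors here; this sketch omits that
    # branch, so every other leaf type is rejected (compare line 619).
    raise TypeError(f"Can only concatenate tensors but got {type(first)}")


# Assumed batch shape: each batch is a list of raw prompt dicts.
batch_a = [{"prompt": "Hi"}, {"prompt": "Hello"}]
batch_b = [{"prompt": "Hi"}, {"prompt": "Hello"}]

try:
    concatenate_sketch([batch_a, batch_b])
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)
```

If that reading is right, the failure depends on the batches still holding raw strings when they reach this step, rather than on anything specific to the generator itself.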
However, replacing the IterableDataset with an analogous Dataset, as done below, fixes the issue:
from datasets import IterableDataset, Dataset
from trl import GRPOConfig, GRPOTrainer

prompts = ["Hi", "Hello"]

dataset = Dataset.from_dict({"prompt": prompts})
training_args = GRPOConfig(
    output_dir="tmp",
    max_steps=1000,
)
trainer = GRPOTrainer(
    model="facebook/opt-350m",
    reward_funcs=lambda prompts, completions, **kwargs: [1] * 8,
    train_dataset=dataset,
    args=training_args,
)
trainer.train()
I have reproduced this on two very different systems, so the environment is unlikely to be the cause. Am I missing something?
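As a stopgap rather than a fix, one option is to materialize a finite slice of the generator and build the map-style Dataset that is known to work. The sketch below uses only the standard library. The helper `take_rows` and the sample count of 6 are hypothetical choices for illustration, and this only helps when the stream does not really need to be infinite.

```python
from itertools import islice

prompts = ["Hi", "Hello"]


def data_generator():
    # Same endless generator as in the failing example.
    while True:
        for s in prompts:
            yield {"prompt": s}


def take_rows(gen, n):
    """Hypothetical helper: pull the first n rows out of a (possibly endless) generator."""
    return list(islice(gen, n))


rows = take_rows(data_generator(), 6)

# Turn the row dicts into a column dict, the shape that Dataset.from_dict
# accepts in the working example: {"prompt": [...]}.
columns = {"prompt": [row["prompt"] for row in rows]}
print(columns)
```

The resulting `columns` dict can then be passed to `Dataset.from_dict(columns)` in place of `{"prompt": prompts}`. This trades the streaming behaviour for a fixed-size dataset, so it does not answer whether IterableDataset is meant to be supported.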