def sample_by_num(data_dict: dict, num: int) -> dict:
    """Return the first ``num`` trajectories stored in ``data_dict``.

    ``data_dict["index"]`` is treated as the list of trajectory start offsets
    into every other (flat, step-aligned) array in the dict, so
    ``index[num]`` is where trajectory ``num`` begins and therefore where the
    first ``num`` trajectories end.  NOTE(review): this layout is inferred from
    how the offsets are used here; confirm against the dataset loader.

    Args:
        data_dict: Mapping of field name to a sliceable sequence (list, numpy
            array, ...).  Must contain ``"index"`` if it has any other key.
        num: Number of trajectories to keep; must be non-negative.  Values at
            or above the number of trajectories keep all data.

    Returns:
        A new dict with the same keys: ``"index"`` truncated to ``num``
        entries and every other field truncated to the steps belonging to
        those trajectories.

    Raises:
        ValueError: If ``num`` is negative.
        KeyError: If ``data_dict`` has non-``"index"`` fields but no
            ``"index"`` key.
    """
    if num < 0:
        # A negative slice bound would count from the end and return garbage.
        raise ValueError(f"num must be non-negative, got {num}")

    stop = None  # None as a slice bound keeps everything.
    if any(key != "index" for key in data_dict):
        starts = data_dict["index"]
        # When num covers every trajectory there is no "next start" offset;
        # previously this raised IndexError, now it keeps all steps.
        if num < len(starts):
            stop = int(starts[num])

    samples = {}
    for key, value in data_dict.items():
        samples[key] = value[:num] if key == "index" else value[:stop]
    return samples

# from: offlinerl/neorl