時間がかかったけど、なんとかzarr化できたのでメモ。
ドキュメントはここ https://github.com/real-stanford/diffusion_policy?tab=readme-ov-file#replaybuffer
data/pusht_cchi_v7_replay.zarr ├── data │ ├── action (25650, 2) float32 │ ├── img (25650, 96, 96, 3) float32 │ ├── keypoint (25650, 9, 2) float32 │ ├── n_contacts (25650, 1) float32 │ └── state (25650, 5) float32 └── meta └── episode_ends (206,) int64
この構成のzarrデータを作る必要がある。 dataの下に入力データを並べ、metaの下にepisode_endsの配列を用意する。 入力データは全部まとめて一連の配列にする。配列の区切り(エピソードの区切り)をepisode_endsに入れておく。
zarr.saveだけではこの構造で保存できず。
store = zarr.DirectoryStore(OUTPUT_FILE)
root = zarr.group(store=store, overwrite=True)
こうやってから、個別にroot.create_datasetで上手くデータ作れました。
import os import glob import zarr from PIL import Image import numpy as np ROOTDIR = 'dofbot_data/OK' OUTPUT_FILE = 'dofbot.zarr' def create_state(root_dir:str) -> zarr.core.Array: # ディレクトリのリストを取得 pose_list = [] for d in sorted(os.listdir(root_dir)): # pose.csvの読み込み pose_csv = os.path.join(root_dir, d, 'pose.csv') with open(pose_csv, 'r') as f: # 1行ずつ読み込む for line in f: pose_words = line.split(',') # float型のリストに変換 pose = [float(w) for w in pose_words] pose_list.append(pose) # zarr配列に変換 z = zarr.array(pose_list, chunks=(1, len(pose_list[0])), dtype='float32') return z def create_images(root_dir:str) -> zarr.core.Array: images = [] for d in sorted(os.listdir(root_dir)): # ディレクトリの中のpngファイルのリストを取得 files = sorted(glob.glob(os.path.join(root_dir, d, '*.png'))) for f in files: # 画像の読み込み img = Image.open(f) # 画像をリサイズ img = img.resize((480, 480)) # 画像をuint8の配列に変換 img_arr = np.array(img).astype('uint8') images.append(img_arr) # zarr配列に変換 z = zarr.array(images, chunks=(1, 480, 480, 3), dtype='float32') return z def calc_episode_ends(dir: str) -> list: # ディレクトリのリストを取得してソート dirs = sorted(os.listdir(dir)) episode_ends = [] ei = 0 for d in dirs: # ディレクトリの中のpngファイルのリストを取得 files = glob.glob(os.path.join(dir, d, '*.png')) ei += len(files) episode_ends.append(ei) return episode_ends def create_action(states: zarr.core.Array, episode_ends:list) -> zarr.core.Array: actions = [] for idx in range(len(states)): if idx + 1 in episode_ends: action = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] else: action_now = states[idx] action_next = states[idx+1] action = [action_next[0] - action_now[0], action_next[1] - action_now[1], action_next[2] - action_now[2], action_next[3] - action_now[3], action_next[4] - action_now[4], action_next[5] - action_now[5]] actions.append(action) z = zarr.array(actions, chunks=(1,), dtype='float32') return z def main(): episode_ends = calc_episode_ends(ROOTDIR) states = create_state(ROOTDIR) actions = create_action(states, episode_ends) images = create_images(ROOTDIR) store = zarr.DirectoryStore(OUTPUT_FILE) root = zarr.group(store=store, overwrite=True) root.create_dataset('meta/episode_ends', data=episode_ends) root.create_dataset('data/state', data=states) root.create_dataset('data/action', data=actions) root.create_dataset('data/img', data=images) print(root.info) print('saved', OUTPUT_FILE) if __name__ == '__main__': main()