from basedformer.dataset import ShardedImageDataset
import time
import random

dataset_folder = "/home/xuser/nvme1/dataset/danbooru_updated_page/"

dataset = ShardedImageDataset(
    dataset_folder + "danbooru_updated.ds",
    dataset_folder + "danbooru_updated.index",
    None,
    bsz=1,
)


def read_entries(keys):
    """Read every entry named in *keys* from the dataset's memory map.

    Entries are visited in the iteration order of *keys*, so the caller
    controls whether the access pattern is sequential or random.

    Returns the sum of the entry sizes recorded in ``dataset.pointer_lookup``.
    """
    bytes_read = 0
    for key in keys:
        offset, size = dataset.pointer_lookup[key]
        # bytes() forces the slice to be materialized so the data is really
        # read. Unlike list(), it does not build one Python object per byte,
        # which would make interpreter overhead dominate the timings.
        # NOTE(review): assumes dataset.mmap is an mmap.mmap (or another
        # buffer-protocol object) - confirm against ShardedImageDataset.
        _ = bytes(dataset.mmap[offset:offset + size])
        bytes_read += size
    return bytes_read


# Untimed warm-up pass: touches every entry once so both timed passes below
# start from the same cache state, and records the total payload size.
total_size = read_entries(dataset.pointer_lookup)
total_mib = total_size // 1024 // 1024

# Benchmark sequential read (entries in index order).
# perf_counter() is monotonic, so the measured duration is not affected by
# system clock adjustments the way time.time() is.
start_time = time.perf_counter()
read_entries(dataset.pointer_lookup)
end_time = time.perf_counter()
print(f"Sequential read took {end_time - start_time} seconds to read {total_mib}MiB")

# Benchmark random read (same entries, shuffled order).
keys = list(dataset.pointer_lookup)
random.shuffle(keys)
start_time = time.perf_counter()
read_entries(keys)
end_time = time.perf_counter()
print(f"Random read took {end_time - start_time} seconds to read {total_mib}MiB")