# prepare_data.py — sample documents of a single language from the Varta dataset.
  1. import fire
  2. import os
  3. from datasets import load_dataset
  4. DATASET = "rahular/varta"
  5. def main(split="validation", lang="hi", docs_to_sample=10_000, save_path="data"):
  6. dataset = load_dataset(DATASET, split=split, streaming=True)
  7. os.makedirs(save_path, exist_ok=True)
  8. with open(os.path.join(save_path, f"{lang}.txt"), "w") as f:
  9. count = 0
  10. for idx, d in enumerate(dataset):
  11. if idx % 10_000 == 0:
  12. print(f"Searched {idx} documents for {lang} documents. Found {count} documents.")
  13. if count >= docs_to_sample:
  14. break
  15. if d["langCode"] == lang:
  16. f.write(d["headline"] + "\n" + d["text"] + "\n")
  17. count += 1
  18. if __name__ == "__main__":
  19. fire.Fire(main)