school_math_dataset.py

import json
import random

import datasets


def get_preprocessed_school_math(dataset_config, tokenizer, split):
    with open("/home/setup/.datasets/jasnah-school-math/precise/grades4to6.json") as f:
        data = json.load(f)

    # Shuffle with a fixed seed so the train/eval split is reproducible across runs.
    rng = random.Random(42)
    order = list(range(len(data)))
    rng.shuffle(order)
    data = [data[i] for i in order]

    # First 70% of the shuffled examples form the train split; the rest is eval.
    train_percent = 70
    train_count = int(len(data) * train_percent / 100)
    if split == "train":
        data = data[:train_count]
    else:
        data = data[train_count:]

    data = {
        "problem": [d["problem"] for d in data],
        "solution": [d["solution"] for d in data],
        "answer": [d["answer"] for d in data],
    }
    dataset = datasets.Dataset.from_dict(data)

    prompt_template = (
        "Solve the following math problem based on the provided description. "
        "The description provided is in Russian and uses LaTeX in some places.\n"
        "You should include a detailed step-by-step explanation of your solution. "
        "The final response must clearly present the precise answer.\n"
        "<problem>\n{problem}\n</problem>\n"
    )
    response_template = "<solution>{solution}</solution>\n<answer>{answer}</answer>\n"

    def apply_prompt_template(sample):
        return {
            "prompt": prompt_template.format(problem=sample["problem"]),
            "response": response_template.format(
                solution=sample["solution"], answer=sample["answer"]
            ),
        }

    dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))

    def tokenize_add_label(sample):
        # BOS is prepended to the prompt and EOS appended to the response;
        # add_special_tokens=False keeps the tokenizer from inserting its own.
        prompt = tokenizer.encode(
            tokenizer.bos_token + sample["prompt"], add_special_tokens=False
        )
        response = tokenizer.encode(
            sample["response"] + tokenizer.eos_token, add_special_tokens=False
        )
        return {
            "input_ids": prompt + response,
            "attention_mask": [1] * (len(prompt) + len(response)),
            # Prompt tokens are masked with -100 so the loss is computed
            # only on the response tokens.
            "labels": [-100] * len(prompt) + response,
        }

    dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
    return dataset
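

# Example usage (a minimal sketch, not part of the original module): any
# Hugging Face tokenizer with bos_token/eos_token set should work here; the
# model name below is an assumption, and dataset_config is unused above, so
# None is passed for it.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    train_ds = get_preprocessed_school_math(None, tokenizer, "train")
    eval_ds = get_preprocessed_school_math(None, tokenizer, "eval")
    print(len(train_ds), len(eval_ds))
    # Decode one example to inspect the prompt/response layout.
    print(tokenizer.decode(train_ds[0]["input_ids"]))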