123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
- import json
- import random
- import datasets
def get_preprocessed_school_math(dataset_config, tokenizer, split):
    """Build the tokenized school-math dataset for one split.

    Loads the problem records from a fixed JSON file, shuffles them
    deterministically (seed 42), keeps the first 70% when ``split`` is
    ``"train"`` and the remaining 30% for any other ``split`` value, renders
    each record into a prompt/response pair and tokenizes both.

    Args:
        dataset_config: Accepted for interface compatibility; not read here.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``
            and the ``bos_token`` / ``eos_token`` string attributes.
        split: ``"train"`` selects the training portion; every other value
            selects the held-out portion.

    Returns:
        A ``datasets.Dataset`` whose columns are ``input_ids``,
        ``attention_mask`` and ``labels``.
    """
    data_path = "/home/setup/.datasets/jasnah-school-math/precise/grades4to6.json"
    # The problems are written in Russian: decode explicitly as UTF-8 instead
    # of relying on the platform's locale encoding.
    with open(data_path, encoding="utf-8") as f:
        records = json.load(f)

    # Deterministic shuffle. The seed and the index-permutation approach are
    # kept exactly as they were so that split membership stays reproducible.
    rng = random.Random(42)
    order = list(range(len(records)))
    rng.shuffle(order)
    records = [records[i] for i in order]

    train_percent = 70
    train_count = int(len(records) * train_percent / 100)
    if split == "train":
        records = records[:train_count]
    else:
        records = records[train_count:]

    # Column-oriented layout expected by Dataset.from_dict.
    columns = {
        "problem": [r["problem"] for r in records],
        "solution": [r["solution"] for r in records],
        "answer": [r["answer"] for r in records],
    }
    dataset = datasets.Dataset.from_dict(columns)

    # Plain (non-f) strings: the placeholders are filled by str.format below.
    # Braces inside the substituted values (e.g. LaTeX) are not re-parsed.
    prompt_template = (
        "Solve the following math problem based on the provided description. "
        "The description provided is in Russian and uses LaTeX in some places.\n"
        "You should include a detailed step-by-step explanation of your solution. "
        "The final response must clearly present the precise answer.\n"
        "<problem>\n{problem}\n</problem>\n"
    )
    response_template = "<solution>{solution}</solution>\n<answer>{answer}</answer>\n"

    def apply_prompt_template(sample):
        """Render one raw record into its prompt and response strings."""
        return {
            "prompt": prompt_template.format(problem=sample["problem"]),
            "response": response_template.format(
                solution=sample["solution"], answer=sample["answer"]
            ),
        }

    dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))

    def tokenize_add_label(sample):
        """Tokenize a prompt/response pair and build the label sequence."""
        # BOS/EOS are added by hand, so the tokenizer must not add its own.
        prompt = tokenizer.encode(
            tokenizer.bos_token + sample["prompt"], add_special_tokens=False
        )
        response = tokenizer.encode(
            sample["response"] + tokenizer.eos_token, add_special_tokens=False
        )
        return {
            "input_ids": prompt + response,
            "attention_mask": [1] * (len(prompt) + len(response)),
            # Prompt positions are labelled -100, presumably so that the
            # training loss ignores them and only the response is learned.
            "labels": [-100] * len(prompt) + response,
        }

    dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
    return dataset
|