{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (c) Meta Platforms, Inc. and affiliates.\n",
    "This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.\n",
    "\n",
    "Use this notebook to pull in datasets and apply preprocessing. Most grammar datasets unfortunately require preprocessing before they can be used for training (for example, JFLEG has 4 targets per input, so we have to re-pair them as 1:1 input/target rows)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],

   "source": [
    "import csv\n",
    "from datasets import load_metric, load_dataset\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ordered (old, new) substitutions applied with plain str.replace.\n",
    "# The first group removes the space in front of punctuation marks and contractions;\n",
    "# the remaining entries rejoin specific digit sequences written with spaces between digits.\n",
    "list_replacements = [\n",
    "  (\" .\", \".\"),\n",
    "  (\" ,\", \",\"),\n",
    "  (\" '\", \"'\"),\n",
    "  (\" ?\", \"?\"),\n",
    "  (\" !\", \"!\"),\n",
    "  (\" :\", \":\"),  # previously mapped to \"!\", which rewrote the punctuation mark\n",
    "  (\" ;\", \";\"),  # previously mapped to \"!\", which rewrote the punctuation mark\n",
    "  (\" n't\", \"n't\"),\n",
    "  # NOTE: a former (\" v\", \"n't\") entry was dropped because it matched every word\n",
    "  # starting with 'v' (e.g. \" very\" became \"n'tery\").\n",
    "  (\"2 0 0 6\", \"2006\"),\n",
    "  (\"5 5\", \"55\"),\n",
    "  (\"4 0 0\", \"400\"),\n",
    "  (\"1 7-5 0\", \"1750\"),\n",
    "  (\"2 0 %\", \"20%\"),\n",
    "  (\"5 0\", \"50\"),\n",
    "  (\"1 2\", \"12\"),\n",
    "  (\"1 0\", \"10\"),\n",
    "  ('\" ballast water', '\"ballast water')\n",
    "  ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def correct_spacing(item):\n",
    "    \"\"\"Return `item` after applying every (old, new) pair of `list_replacements`, in list order.\"\"\"\n",
    "    cleaned = item\n",
    "    for old, new in list_replacements:\n",
    "        cleaned = cleaned.replace(old, new)\n",
    "    return cleaned\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_csv(csv_path, dataset):\n",
    "    \"\"\"Write one (input, target) CSV row per sentence/correction pair in `dataset`.\n",
    "\n",
    "    Each item of `dataset` must provide a \"sentence\" string and a \"corrections\"\n",
    "    iterable of strings. Both sides go through `correct_spacing`; a pair is\n",
    "    skipped when either side is empty after cleaning.\n",
    "\n",
    "    Args:\n",
    "        csv_path: destination file; overwritten if it already exists.\n",
    "        dataset: iterable of items as described above.\n",
    "    \"\"\"\n",
    "    # Explicit UTF-8 so the output does not depend on the platform's default encoding\n",
    "    # (pandas, which reads CSV files as UTF-8 by default, is the usual consumer).\n",
    "    with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:\n",
    "        writer = csv.writer(csvfile)\n",
    "        writer.writerow([\"input\", \"target\"])\n",
    "        for case in dataset:\n",
    "            input_text = correct_spacing(case[\"sentence\"])\n",
    "            if not input_text:\n",
    "                # a few of the cases contain blank strings; nothing to pair up\n",
    "                continue\n",
    "            for correction in case[\"corrections\"]:\n",
    "                correction = correct_spacing(correction)\n",
    "                if correction:\n",
    "                    writer.writerow([input_text, correction])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In JFLEG, the `validation` split will be used as our 'train' set and the `test` split as our 'validation' set."
   ]
  },
  {
   "cell_type": "code",

   "execution_count": 5,

   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [

      "Found cached dataset jfleg (/data/home/mreso/.cache/huggingface/datasets/jfleg/default/1.0.0/ed4ab2367351fe31949f48849ae6732b164f0d5ea6bb5d4357ff4293ac89511b)\n",
      "Found cached dataset jfleg (/data/home/mreso/.cache/huggingface/datasets/jfleg/default/1.0.0/ed4ab2367351fe31949f48849ae6732b164f0d5ea6bb5d4357ff4293ac89511b)\n"

     ]
    }
   ],
   "source": [
    "train_dataset = load_dataset(\"jfleg\", split='validation[:]') \n",
    "eval_dataset = load_dataset(\"jfleg\", split='test[:]')\n"
   ]
  },
  {
   "cell_type": "code",

   "execution_count": 6,

   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset({\n",
      "    features: ['sentence', 'corrections'],\n",
      "    num_rows: 755\n",
      "})\n",
      "Dataset({\n",
      "    features: ['sentence', 'corrections'],\n",
      "    num_rows: 748\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "print(train_dataset)\n",
    "print(eval_dataset)\n"
   ]
  },
  {
   "cell_type": "code",

   "execution_count": 7,

   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Students can focus on only a few subjects they are intwerested in and they will become an experts in those areas . \n",
      "['Students can focus on only a few subjects they are interested in and they will become experts in those areas . ', 'Students can focus on only a few subjects they are interested in and they will become experts in those areas . ', 'Students can focus on only a few subjects they are interested in and they will become an expert in those areas . ', 'Students can focus on only a few subjects they are interested in and they will become an expert in those areas . ']\n"
     ]
    }
   ],
   "source": [
    "print(train_dataset['sentence'][22])\n",
    "print(train_dataset['corrections'][22])"
   ]
  },
  {
   "cell_type": "code",

   "execution_count": 8,

   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Students can focus on only a few subjects they are intwerested in and they will become an experts in those areas. '"
      ]
     },

     "execution_count": 8,

     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clean22 = correct_spacing(train_dataset['sentence'][22])\n",
    "clean22"
   ]
  },
  {
   "cell_type": "code",

   "execution_count": 9,

   "metadata": {},
   "outputs": [],
   "source": [
    "# Output folders live next to the notebook.\n",
    "# If you only use 'jfleg' as the folder name, hf will try and use that and complain.\n",
    "jfleg_dir = Path.cwd().joinpath('jfleg_dataset')\n",
    "jfleg_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "c4_dir = Path.cwd().joinpath('c4_dataset')\n",
    "c4_dir.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Process Jfleg data  "
   ]
  },
  {
   "cell_type": "code",

   "execution_count": 10,

   "metadata": {},
   "outputs": [],
   "source": [
    "j_train_file = jfleg_dir/'jtrain.csv'\n",
    "j_eval_file = jfleg_dir/'jeval.csv'"
   ]
  },
  {
   "cell_type": "code",

   "execution_count": 11,

   "metadata": {},
   "outputs": [],
   "source": [
    "generate_csv(j_train_file, train_dataset)"
   ]
  },
  {
   "cell_type": "code",

   "execution_count": 12,

   "metadata": {},
   "outputs": [],
   "source": [
    "generate_csv(j_eval_file, eval_dataset)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Process C4_200M (!) - we'll pull 10K to start"
   ]
  },
  {
   "cell_type": "code",

   "execution_count": 13,

   "metadata": {},
   "outputs": [],
   "source": [
    "c4_dataset = load_dataset(\"liweili/c4_200m\", streaming = True)"
   ]
  },
  {
   "cell_type": "code",

   "execution_count": 14,

   "metadata": {},
   "outputs": [],
   "source": [
    "iterator = iter(c4_dataset['train'])"
   ]
  },
  {
   "cell_type": "code",

   "execution_count": 15,

   "metadata": {},
   "outputs": [],
   "source": [
    "def c4_generate_csv(csv_path, iterator, num_examples):\n",
    "    \"\"\"Write up to `num_examples` cleaned (input, target) rows from `iterator` to a CSV file.\n",
    "\n",
    "    Each item drawn from `iterator` must provide \"input\" and \"output\" strings. Both\n",
    "    go through `correct_spacing`; a pair is skipped when either side is empty after\n",
    "    cleaning, so the file can hold fewer than `num_examples` rows.\n",
    "\n",
    "    Args:\n",
    "        csv_path: destination file; overwritten if it already exists.\n",
    "        iterator: source of examples; advanced by at most `num_examples` items.\n",
    "        num_examples: maximum number of examples to draw from `iterator`.\n",
    "    \"\"\"\n",
    "    # Explicit UTF-8 so the output does not depend on the platform's default encoding\n",
    "    # (pandas, which reads CSV files as UTF-8 by default, is the usual consumer).\n",
    "    with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:\n",
    "        writer = csv.writer(csvfile)\n",
    "        writer.writerow([\"input\", \"target\"])\n",
    "        # zip() ends cleanly if `iterator` runs dry, instead of next() raising StopIteration;\n",
    "        # range() is listed first so no extra example is consumed once the limit is hit.\n",
    "        for _, data in zip(range(num_examples), iterator):\n",
    "            input_text = correct_spacing(data[\"input\"])\n",
    "            correction = correct_spacing(data[\"output\"])\n",
    "            if input_text and correction:\n",
    "                writer.writerow([input_text, correction])"
   ]
  },
  {
   "cell_type": "code",

   "execution_count": 16,

   "metadata": {},
   "outputs": [],
   "source": [
    "c4_dir = Path.cwd()/'c4_dataset'\n",
    "c4_dir.mkdir(parents=True,exist_ok=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can modify the following cells to build the CSV file with the desired number of examples; here we use 10k for a quick test."
   ]
  },
  {
   "cell_type": "code",

   "execution_count": 17,

   "metadata": {},
   "outputs": [],
   "source": [
    "c4_filename = c4_dir/'c4train_10k.csv'"
   ]
  },
  {
   "cell_type": "code",

   "execution_count": 18,

   "metadata": {},
   "outputs": [],
   "source": [
    "c4_generate_csv(c4_filename, iterator, num_examples=10000)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a single training file by combining jtrain and c4train"
   ]
  },
  {
   "cell_type": "code",

   "execution_count": 19,

   "metadata": {},
   "outputs": [],
   "source": [
    "merge_list = [j_train_file, c4_filename, ]"
   ]
  },
  {
   "cell_type": "code",

   "execution_count": 20,

   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",

   "execution_count": 21,

   "metadata": {},
   "outputs": [],
   "source": [
    "combined_csv = pd.concat([pd.read_csv(fn) for fn in merge_list])\n"
   ]
  },
  {
   "cell_type": "code",

   "execution_count": 22,

   "metadata": {},
   "outputs": [],
   "source": [
    "merged_name = \"gtrain_10k.csv\""
   ]
  },
  {
   "cell_type": "code",

   "execution_count": 23,

   "metadata": {},
   "outputs": [],
   "source": [
    "combined_csv.to_csv(merged_name, index=False, encoding = 'utf-8-sig', )"
   ]
  },
  {
   "cell_type": "code",

   "execution_count": 24,

   "metadata": {},
   "outputs": [],
   "source": [
    "eval_name = \"grammar_validation.csv\""
   ]

  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_csv = pd.read_csv(j_eval_file)\n",
    "eval_csv.to_csv(eval_name, index=False, encoding = 'utf-8-sig', )"
   ]

  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "5b2c14c5f2a3b21e6c2412c8196f5145870350e81c0b737cae3e5c60eb1e1eac"
  },
  "kernelspec": {

   "display_name": "Python 3 (ipykernel)",

   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"

  }
 },
 "nbformat": 4,
 "nbformat_minor": 4

}