eval_main.py

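"""Command-line entry point for running lm-eval-harness evaluations.

Parses CLI arguments, resolves the requested tasks (including external YAML
task configs), calls evaluator.simple_evaluate, and writes the aggregated
results (and, with --log_samples, per-sample outputs) to --output_path.
"""
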
import argparse
import json
import logging
import os
import re
import sys
from pathlib import Path
from typing import Union

import numpy as np

from lm_eval import evaluator, utils
from lm_eval.api.registry import ALL_TASKS
from lm_eval.tasks import include_path, initialize_tasks
from lm_eval.utils import make_table


def _handle_non_serializable(o):
    if isinstance(o, np.int64) or isinstance(o, np.int32):
        return int(o)
    elif isinstance(o, set):
        return list(o)
    else:
        return str(o)


def parse_eval_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("--model", "-m", default="hf", help="Name of model e.g. `hf`")
    parser.add_argument(
        "--tasks",
        "-t",
        default=None,
        metavar="task1,task2",
        help="To get full list of tasks, use the command lm-eval --tasks list",
    )
    parser.add_argument(
        "--model_args",
        "-a",
        default="",
        help="Comma separated string arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`",
    )
    parser.add_argument(
        "--num_fewshot",
        "-f",
        type=int,
        default=None,
        metavar="N",
        help="Number of examples in few-shot context",
    )
    parser.add_argument(
        "--batch_size",
        "-b",
        type=str,
        default=1,
        metavar="auto|auto:N|N",
        help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=None,
        metavar="N",
        help="Maximal batch size to try with --batch_size auto.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g. cuda, cuda:0, cpu).",
    )
    parser.add_argument(
        "--output_path",
        "-o",
        default=None,
        type=str,
        metavar="DIR|DIR/file.json",
        help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
    )
    parser.add_argument(
        "--limit",
        "-L",
        type=float,
        default=None,
        metavar="N|0<N<1",
        help="Limit the number of examples per task. "
        "If <1, limit is a percentage of the total number of examples.",
    )
    parser.add_argument(
        "--use_cache",
        "-c",
        type=str,
        default=None,
        metavar="DIR",
        help="A path to a sqlite db file for caching model responses. `None` if not caching.",
    )
    parser.add_argument("--decontamination_ngrams_path", default=None)  # TODO: not used
    parser.add_argument(
        "--check_integrity",
        action="store_true",
        help="Whether to run the relevant part of the test suite for the tasks.",
    )
    parser.add_argument(
        "--write_out",
        "-w",
        action="store_true",
        default=False,
        help="Prints the prompt for the first few documents.",
    )
    parser.add_argument(
        "--log_samples",
        "-s",
        action="store_true",
        default=False,
        help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
    )
    parser.add_argument(
        "--show_config",
        action="store_true",
        default=False,
        help="If True, shows the full config of all tasks at the end of the evaluation.",
    )
    parser.add_argument(
        "--include_path",
        type=str,
        default=None,
        metavar="DIR",
        help="Additional path to include if there are external tasks to include.",
    )
    parser.add_argument(
        "--gen_kwargs",
        default=None,
        help=(
            "String arguments for model generation on greedy_until tasks,"
            " e.g. `temperature=0,top_k=0,top_p=0`."
        ),
    )
    parser.add_argument(
        "--verbosity",
        "-v",
        type=str.upper,
        default="INFO",
        metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
        help="Controls the reported logging error level. Set to DEBUG when testing + adding new task configurations for comprehensive log output.",
    )
    return parser.parse_args()


def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
    if not args:
        # we allow for args to be passed externally, else we parse them ourselves
        args = parse_eval_args()

    eval_logger = utils.eval_logger
    eval_logger.setLevel(getattr(logging, f"{args.verbosity}"))
    eval_logger.info(f"Verbosity set to {args.verbosity}")
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    initialize_tasks(args.verbosity)

    if args.limit:
        eval_logger.warning(
            " --limit SHOULD ONLY BE USED FOR TESTING. "
            "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
        )

    if args.include_path is not None:
        eval_logger.info(f"Including path: {args.include_path}")
        include_path(args.include_path)

    if args.tasks is None:
        task_names = ALL_TASKS
    elif args.tasks == "list":
        eval_logger.info(
            "Available Tasks:\n - {}".format("\n - ".join(sorted(ALL_TASKS)))
        )
        sys.exit()
    else:
        if os.path.isdir(args.tasks):
            import glob

            task_names = []
            yaml_path = os.path.join(args.tasks, "*.yaml")
            for yaml_file in glob.glob(yaml_path):
                config = utils.load_yaml_config(yaml_file)
                task_names.append(config)
        else:
            tasks_list = args.tasks.split(",")
            task_names = utils.pattern_match(tasks_list, ALL_TASKS)
            for task in [task for task in tasks_list if task not in task_names]:
                if os.path.isfile(task):
                    config = utils.load_yaml_config(task)
                    task_names.append(config)
            task_missing = [
                task
                for task in tasks_list
                if task not in task_names and "*" not in task
            ]  # we don't want errors if a wildcard ("*") task name was used

            if task_missing:
                missing = ", ".join(task_missing)
                eval_logger.error(
                    f"Tasks were not found: {missing}\n"
                    f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
                )
                raise ValueError(
                    f"Tasks not found: {missing}. Try `lm-eval --tasks list` for list of available tasks, or '--verbosity DEBUG' to troubleshoot task registration issues."
                )

    if args.output_path:
        path = Path(args.output_path)
        # check if file or 'dir/results.json' exists
        if path.is_file() or Path(args.output_path).joinpath("results.json").is_file():
            eval_logger.warning(
                f"File already exists at {path}. Results will be overwritten."
            )
            output_path_file = path.joinpath("results.json")
            assert not path.is_file(), "File already exists"
        # if path json then get parent dir
        elif path.suffix in (".json", ".jsonl"):
            output_path_file = path
            path.parent.mkdir(parents=True, exist_ok=True)
            path = path.parent
        else:
            path.mkdir(parents=True, exist_ok=True)
            output_path_file = path.joinpath("results.json")
    elif args.log_samples and not args.output_path:
        assert args.output_path, "Specify --output_path"

    eval_logger.info(f"Selected Tasks: {task_names}")
    print(f"type of model args: {type(args.model_args)}")
    print("*************************************")

    results = evaluator.simple_evaluate(
        model=args.model,
        model_args=args.model_args,
        tasks=task_names,
        num_fewshot=args.num_fewshot,
        batch_size=args.batch_size,
        max_batch_size=args.max_batch_size,
        device=args.device,
        use_cache=args.use_cache,
        limit=args.limit,
        decontamination_ngrams_path=args.decontamination_ngrams_path,
        check_integrity=args.check_integrity,
        write_out=args.write_out,
        log_samples=args.log_samples,
        gen_kwargs=args.gen_kwargs,
    )

    if results is not None:
        if args.log_samples:
            samples = results.pop("samples")
        dumped = json.dumps(
            results, indent=2, default=_handle_non_serializable, ensure_ascii=False
        )
        if args.show_config:
            print(dumped)

        batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))

        if args.output_path:
            output_path_file.open("w").write(dumped)

            if args.log_samples:
                for task_name, config in results["configs"].items():
                    output_name = "{}_{}".format(
                        re.sub("/|=", "__", args.model_args), task_name
                    )
                    filename = path.joinpath(f"{output_name}.jsonl")
                    samples_dumped = json.dumps(
                        samples[task_name],
                        indent=2,
                        default=_handle_non_serializable,
                        ensure_ascii=False,
                    )
                    filename.write_text(samples_dumped, encoding="utf-8")

        print(
            f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
            f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
        )
        print(make_table(results))
        if "groups" in results:
            print(make_table(results, "groups"))


if __name__ == "__main__":
    cli_evaluate()
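
# Example invocation (an illustrative sketch, not part of the original script):
# it assumes the lm-eval-harness dependencies are installed and that the `hf`
# backend can load the referenced checkpoint; the task names and the output
# directory below are placeholders. All flags correspond to parse_eval_args above.
#
#   python eval_main.py \
#       --model hf \
#       --model_args pretrained=EleutherAI/pythia-160m,dtype=float32 \
#       --tasks lambada_openai,hellaswag \
#       --batch_size auto \
#       --output_path results/ \
#       --log_samples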