{ "cells": [ { "cell_type": "markdown", "id": "f0fad9db618f47a2", "metadata": { "id": "f0fad9db618f47a2" }, "source": [ "# Tutorial AAMAS 2026\n", "\n", "Welcome to the hands-on session! Here's how this notebook works:\n", "\n", "1. **Predict.** Before each result we'll commit a guess via a small poll.\n", "2. **Run.** Execute the cell and see whether the algorithm agrees with you.\n", "3. **Play.** Sliders let you change the number of traces to learn from, and watch the metrics react.\n", "4. **Compete.** We end with a leaderboard challenge — best solver wins.\n", "\n", "Tip: in Colab/Jupyter, if you want to run all cells at once, click `Cell → Run All` only *after* you've made your predictions in each section." ] }, { "cell_type": "code", "execution_count": null, "id": "bb6b76bb1fcd440d", "metadata": { "id": "bb6b76bb1fcd440d" }, "outputs": [], "source": [ "%pip install amlgym ipywidgets matplotlib pandas > /dev/null 2>&1" ] }, { "cell_type": "code", "execution_count": null, "id": "2556171f2c214c04", "metadata": { "id": "2556171f2c214c04" }, "outputs": [], "source": [ "# Shared utilities used throughout the notebook\n", "import re\n", "import difflib\n", "import ipywidgets as widgets\n", "import matplotlib.pyplot as plt\n", "import pandas as pd\n", "from IPython.display import display, HTML, Markdown\n", "\n", "# ggplot style for every plot\n", "plt.style.use(\"ggplot\")\n", "\n", "# ---------- polls & reveals ----------\n", "_POLL_HTML = \"\"\"
\n", "
{header}
\n", "
{question}
\"\"\"\n", "\n", "def make_poll(question, options, header=\"\\U0001F52E Cast your prediction\"):\n", " \"\"\"Render a styled poll. Returns the widget so we can read .value later.\"\"\"\n", " display(HTML(_POLL_HTML.format(header=header, question=question)))\n", " w = widgets.RadioButtons(options=options, layout=widgets.Layout(width=\"auto\"))\n", " display(w)\n", " return w\n", "\n", "def reveal(poll, truth, formatter=str):\n", " \"\"\"Reveal whether the audience's guess matches the computed truth.\"\"\"\n", " chosen = poll.value\n", " if chosen == truth:\n", " accent, bg, border = \"#15803d\", \"#dcfce7\", \"#16a34a\"\n", " icon, msg = \"\\U0001F31F\", \"Spot on!\"\n", " else:\n", " accent, bg, border = \"#92400e\", \"#fef3c7\", \"#d97706\"\n", " icon, msg = \"\\U0001F9ED\", \"Off the mark - but now you know.\"\n", " truth_str = formatter(truth) if formatter is not str else truth\n", " display(HTML(\n", " \"
\"\n", " \"
\"\n", " + icon + \" \" + msg + \"
\"\n", " \"
Your guess: \" + str(chosen) + \"
\"\n", " \"
Truth: \" + str(truth_str) + \"
\"\n", " ))\n", "\n", "def pick_bucket(value, options):\n", " \"\"\"Return the poll option whose numeric range contains ``value``.\"\"\"\n", " for opt in options:\n", " if re.search(r\"\\bexactly\\s+1\", opt, re.I) and value >= 1.0 - 1e-9:\n", " return opt\n", " if re.search(r\"\\bbelow\\b\", opt, re.I):\n", " nums = re.findall(r\"\\d+(?:\\.\\d+)?\", opt)\n", " if nums and value < float(nums[0]):\n", " return opt\n", " nums = re.findall(r\"\\d+(?:\\.\\d+)?\", opt)\n", " if len(nums) >= 2:\n", " lo, hi = float(nums[0]), float(nums[1])\n", " if lo <= value < hi:\n", " return opt\n", " if len(nums) == 1 and \"perfect\" in opt.lower() and value >= float(nums[0]) - 1e-9:\n", " return opt\n", " return options[-1]\n", "\n", "\n", "def _tokenize(text):\n", " text = re.sub(r\";[^\\n]*\", \"\", text)\n", " return re.findall(r\"\\(|\\)|[^\\s()]+\", text)\n", "\n", "def _parse_all(tokens):\n", " def helper(i):\n", " if tokens[i] == \"(\":\n", " out, i = [], i + 1\n", " while i < len(tokens) and tokens[i] != \")\":\n", " child, i = helper(i)\n", " out.append(child)\n", " return out, i + 1\n", " return tokens[i], i + 1\n", " forms, i = [], 0\n", " while i < len(tokens):\n", " f, i = helper(i)\n", " forms.append(f)\n", " return forms\n", "\n", "def _flat(node):\n", " if isinstance(node, str):\n", " return node\n", " return \"(\" + \" \".join(_flat(c) for c in node) + \")\"\n", "\n", "def _canon_var_list(items):\n", " \"\"\"Rename ?vars left-to-right to ?p1, ?p2, ... 
; return (new_items, map).\"\"\"\n", " out, m, idx, j = [], {}, 0, 0\n", " while j < len(items):\n", " tok = items[j]\n", " if isinstance(tok, str) and tok.startswith(\"?\"):\n", " idx += 1\n", " new = \"?p{}\".format(idx)\n", " m[tok] = new\n", " out.append(new)\n", " j += 1\n", " if j + 1 < len(items) and items[j] == \"-\":\n", " out.extend([\"-\", items[j + 1]])\n", " j += 2\n", " else:\n", " out.append(tok)\n", " j += 1\n", " return out, m\n", "\n", "def _canon_predicate_decl(pred):\n", " if not (isinstance(pred, list) and pred):\n", " return pred\n", " new_args, _ = _canon_var_list(pred[1:])\n", " return [pred[0]] + new_args\n", "\n", "def _sort_and(node):\n", " if isinstance(node, list) and node and isinstance(node[0], str) and node[0] == \"and\":\n", " kids = [_sort_and(c) for c in node[1:]]\n", " kids.sort(key=_flat)\n", " return [\"and\"] + kids\n", " if isinstance(node, list):\n", " return [_sort_and(c) if isinstance(c, list) else c for c in node]\n", " return node\n", "\n", "def _canon_action(action):\n", " if not (isinstance(action, list) and action and action[0] == \":action\"):\n", " return action\n", " name = action[1]\n", " kv, i = {}, 2\n", " while i < len(action) - 1:\n", " kv[action[i]] = action[i + 1]\n", " i += 2\n", " canon_params, params_map = [], {}\n", " if \":parameters\" in kv and isinstance(kv[\":parameters\"], list):\n", " canon_params, params_map = _canon_var_list(kv[\":parameters\"])\n", " def rename(n):\n", " if isinstance(n, str):\n", " return params_map.get(n, n)\n", " return [rename(c) for c in n]\n", " out = [\":action\", name]\n", " if \":parameters\" in kv:\n", " out += [\":parameters\", canon_params]\n", " if \":precondition\" in kv:\n", " out += [\":precondition\", _sort_and(rename(kv[\":precondition\"]))]\n", " if \":effect\" in kv:\n", " out += [\":effect\", _sort_and(rename(kv[\":effect\"]))]\n", " return out\n", "\n", "def _canon_predicates_block(block):\n", " if not (isinstance(block, list) and block and block[0] == 
\":predicates\"):\n", " return block\n", " preds = [_canon_predicate_decl(p) for p in block[1:]]\n", " preds.sort(key=_flat)\n", " return [\":predicates\"] + preds\n", "\n", "def _render(node, indent=0):\n", " if isinstance(node, str):\n", " return node\n", " if not node:\n", " return \"()\"\n", " head = node[0]\n", " if isinstance(head, str) and head == \"and\" and len(node) > 1:\n", " sub_ind = \" \" * (indent + 1)\n", " body = \"\\n\".join(sub_ind + _render(c, indent + 1) for c in node[1:])\n", " return \"(and\\n\" + body + \")\"\n", " if isinstance(head, str) and head == \":action\":\n", " ind = \" \" * (indent + 1)\n", " out = \"(:action \" + str(node[1])\n", " i = 2\n", " while i < len(node) - 1:\n", " kw, val = node[i], node[i + 1]\n", " val_str = _render(val, indent + 2)\n", " if \"\\n\" in val_str:\n", " out += \"\\n\" + ind + kw + \"\\n\" + ind + \" \" + val_str.replace(\"\\n\", \"\\n \")\n", " else:\n", " out += \"\\n\" + ind + kw + \" \" + val_str\n", " i += 2\n", " return out + \")\"\n", " if isinstance(head, str) and head == \":predicates\":\n", " ind = \" \" * (indent + 1)\n", " body = \"\\n\".join(ind + _render(p, indent + 1) for p in node[1:])\n", " return \"(:predicates\\n\" + body + \")\"\n", " parts = [_render(c, indent + 1) if isinstance(c, list) else c for c in node]\n", " return \"(\" + \" \".join(parts) + \")\"\n", "\n", "def pddl_canonical(text):\n", " \"\"\"Return a canonical, diff-friendly PDDL string.\"\"\"\n", " try:\n", " forms = _parse_all(_tokenize(text.lower()))\n", " except Exception:\n", " return text\n", " lines = []\n", " for form in forms:\n", " if not isinstance(form, list):\n", " continue\n", " if form and form[0] == \"define\":\n", " lines.append(\"(define\")\n", " non_actions, actions = [], []\n", " for elem in form[1:]:\n", " if isinstance(elem, list) and elem and elem[0] == \":action\":\n", " actions.append(_canon_action(elem))\n", " elif isinstance(elem, list) and elem and elem[0] == \":predicates\":\n", " 
non_actions.append(_canon_predicates_block(elem))\n", " else:\n", " non_actions.append(elem)\n", " actions.sort(key=lambda a: a[1] if len(a) > 1 else \"\")\n", " for elem in non_actions:\n", " rendered = _render(elem, 1)\n", " lines.append(\" \" + rendered.replace(\"\\n\", \"\\n \"))\n", " for elem in actions:\n", " rendered = _render(elem, 1)\n", " lines.append(\"\")\n", " lines.append(\" \" + rendered.replace(\"\\n\", \"\\n \"))\n", " lines.append(\")\")\n", " else:\n", " lines.append(_render(form))\n", " return \"\\n\".join(lines) + \"\\n\"\n", "\n", "_DIFF_CSS = \"\"\"\"\"\"\n", "\n", "def html_diff(learned_path, ref_path, title_l=\"Learned\", title_r=\"Reference\"):\n", " \"\"\"Side-by-side, syntax-aware diff of two PDDL files.\"\"\"\n", " with open(learned_path) as f1, open(ref_path) as f2:\n", " l1 = pddl_canonical(f1.read()).splitlines(keepends=True)\n", " l2 = pddl_canonical(f2.read()).splitlines(keepends=True)\n", " table = difflib.HtmlDiff(wrapcolumn=80).make_table(\n", " l1, l2, fromdesc=title_l, todesc=title_r,\n", " context=False,\n", " )\n", " display(HTML(_DIFF_CSS + table))\n", "\n", "import re\n", "from html import escape\n", "from itertools import zip_longest\n", "from IPython.display import display, HTML\n", "\n", "# ---- styling: same shape as html_diff, no diff colors, dark readable text ----\n", "_SXS_CSS = \"\"\"\"\"\"\n", "\n", "# Color rules — tuned for white background, GitHub-light palette\n", "_PDDL_LOGIC = {\"and\", \"or\", \"not\", \"when\", \"forall\", \"exists\", \"imply\"}\n", "_PDDL_DEFINE = {\"define\", \"domain\", \"problem\"}\n", "\n", "def _pddl_highlight(line):\n", " \"\"\"Tokenize a PDDL line and wrap each token in a colored span.\"\"\"\n", " if not line:\n", " return \" \"\n", " out, i, next_is_head = [], 0, False\n", " while i < len(line):\n", " ch = line[i]\n", " # comment to end of line\n", " if ch == \";\":\n", " out.append(f\"\"\n", " f\"{escape(line[i:])}\")\n", " break\n", " # whitespace preserved verbatim\n", " if 
ch in \" \\t\":\n", " j = i\n", " while j < len(line) and line[j] in \" \\t\":\n", " j += 1\n", " out.append(line[i:j])\n", " i = j\n", " continue\n", " # parens (muted)\n", " if ch == \"(\":\n", " out.append(\"(\")\n", " i += 1; next_is_head = True\n", " continue\n", " if ch == \")\":\n", " out.append(\")\")\n", " i += 1\n", " continue\n", " # any other token\n", " j = i\n", " while j < len(line) and line[j] not in \" \\t()\":\n", " j += 1\n", " tok = line[i:j]; i = j\n", " low = tok.lower()\n", " if tok.startswith(\"?\"): # variable\n", " style = \"color:#0969da\"\n", " elif tok.startswith(\":\"): # :keyword (incl. :strips, :typing, :action ...)\n", " style = \"color:#8250df;font-weight:600\"\n", " elif low in _PDDL_LOGIC: # and, not, when, ...\n", " style = \"color:#cf222e;font-weight:600\"\n", " elif low in _PDDL_DEFINE: # define, domain, problem\n", " style = \"color:#8250df;font-weight:700\"\n", " elif tok == \"-\": # type separator\n", " style = \"color:#8b949e\"\n", " elif next_is_head: # predicate / action / operator head\n", " style = \"color:#953800;font-weight:500\"\n", " else: # objects, types — default dark slate\n", " style = \"color:#1f2937\"\n", " out.append(f\"{escape(tok)}\")\n", " next_is_head = False\n", " return \"\".join(out) or \" \"\n", "\n", "def html_side_by_side(left_path, right_path,\n", " title_l=\"Learned\", title_r=\"Reference\"):\n", " \"\"\"Print two PDDL files side by side with syntax highlighting (no diff).\"\"\"\n", " with open(left_path) as f1: l = pddl_canonical(f1.read()).splitlines()\n", " with open(right_path) as f2: r = pddl_canonical(f2.read()).splitlines()\n", " rows = []\n", " for n, (a, b) in enumerate(zip_longest(l, r, fillvalue=\"\"), start=1):\n", " rows.append(\n", " \"\"\n", " f\"{n}{_pddl_highlight(a)}\"\n", " f\"{n}{_pddl_highlight(b)}\"\n", " \"\"\n", " )\n", " table = (\n", " \"\"\n", " \"\"\n", " \"\"\n", " f\"\"\n", " f\"\"\n", " \"\" + \"\".join(rows) + \"
{escape(title_l)}{escape(title_r)}
\"\n", " )\n", " display(HTML(_SXS_CSS + table))\n", "\n", "# ---------- plotting helper ----------\n", "_GG_COLORS = [\"#e24a33\", \"#348abd\", \"#988ed5\", \"#fbc15e\", \"#777777\", \"#8eba42\"]\n", "\n", "def bar_compare(values, title, ylabel, ylim=(0, 1)):\n", " with plt.style.context(\"ggplot\"):\n", " fig, ax = plt.subplots(figsize=(6, 3.5))\n", " labels = list(values.keys())\n", " ax.bar(labels, [values[k] for k in labels],\n", " color=_GG_COLORS[: len(labels)])\n", " ax.set_title(title)\n", " ax.set_ylabel(ylabel)\n", " if ylim:\n", " ax.set_ylim(*ylim)\n", " for i, v in enumerate(values.values()):\n", " ax.text(i, v + 0.02, \"{:.2f}\".format(v), ha=\"center\", fontsize=10,\n", " color=\"#333333\")\n", " plt.tight_layout()\n", " plt.show()\n", "\n", "# Silence the unified-planning credits banner\n", "import unified_planning\n", "from unified_planning import shortcuts\n", "shortcuts.get_environment().credits_stream = None\n" ] }, { "cell_type": "markdown", "id": "b2493ad279574fc5", "metadata": { "id": "b2493ad279574fc5" }, "source": [ "## 1. Exploring the benchmarks\n", "\n", "AMLGym ships dozens of IPC domains. Let's first see what's available." ] }, { "cell_type": "code", "execution_count": null, "id": "d25d9ed865d34368", "metadata": { "id": "d25d9ed865d34368" }, "outputs": [], "source": [ "from amlgym.benchmarks import print_domains\n", "print_domains()" ] }, { "cell_type": "markdown", "id": "bac2b819bd4048b8", "metadata": { "id": "bac2b819bd4048b8" }, "source": [ "Pick a domain to inspect: let's start with **blocksworld**." 
] }, { "cell_type": "code", "execution_count": null, "id": "d5f58cdabc5d48d9", "metadata": { "id": "d5f58cdabc5d48d9" }, "outputs": [], "source": [ "from amlgym.benchmarks import get_domain_path, get_domain\n", "domain_path = get_domain_path('blocksworld')\n", "print(domain_path)" ] }, { "cell_type": "code", "execution_count": null, "id": "e70edb55cf284e80", "metadata": { "id": "e70edb55cf284e80" }, "outputs": [], "source": [ "domain_pddl = get_domain('blocksworld')\n", "print(domain_pddl)" ] }, { "cell_type": "markdown", "id": "8b92c56fb2ac4a5a", "metadata": { "id": "8b92c56fb2ac4a5a" }, "source": [ "Each domain comes with **10 training trajectories** generated from a known reference model. These are what the learners get to see." ] }, { "cell_type": "code", "execution_count": null, "id": "35150e4374e54ea7", "metadata": { "id": "35150e4374e54ea7" }, "outputs": [], "source": [ "from amlgym.benchmarks import get_trajectories_path, get_trajectories, get_problems_path\n", "from pprint import pprint\n", "\n", "trajectory_paths = get_trajectories_path('blocksworld')\n", "pprint(trajectory_paths)" ] }, { "cell_type": "code", "execution_count": null, "id": "c5815dbea29f471f", "metadata": { "id": "c5815dbea29f471f" }, "outputs": [], "source": [ "trajectory = get_trajectories('blocksworld')[0]\n", "print(trajectory)" ] }, { "cell_type": "code", "execution_count": null, "id": "29b854de32d24bd8", "metadata": { "id": "29b854de32d24bd8" }, "outputs": [], "source": [ "problem_path = get_problems_path('blocksworld', kind=\"learning\")[0]\n", "with open(problem_path) as f:\n", " print(f.read())" ] }, { "cell_type": "markdown", "id": "314ff5cec70c4578", "metadata": { "id": "314ff5cec70c4578" }, "source": [ "## 2. Passive learning with full observability — SAM\n", "\n", "AMLGym registers several passive learners. Each one assumes something different\n", "about the trajectories: full observability, partial observability, or noisy state observations." 
] }, { "cell_type": "code", "execution_count": null, "id": "0f6256c113d9429f", "metadata": { "id": "0f6256c113d9429f" }, "outputs": [], "source": [ "from amlgym.algorithms import print_algorithms\n", "print_algorithms()" ] }, { "cell_type": "markdown", "id": "82f42f681a614423", "metadata": { "id": "82f42f681a614423" }, "source": [ "### 🔮 Prediction: before we learn\n", "\n", "With **10 full-observability trajectories** of blocksworld, what do you think SAM's *syntactic precision* will be?" ] }, { "cell_type": "code", "execution_count": null, "id": "a3a933da110a49f1", "metadata": { "id": "a3a933da110a49f1" }, "outputs": [], "source": [ "poll_sam_precision = make_poll(\n", " \"Pick the range you'd bet on:\",\n", " [\"Below 0.5 (a lot of spurious preconditions)\",\n", " \"0.5 – 0.8 (mostly right but some noise)\",\n", " \"0.8 – 0.99 (almost perfect)\",\n", " \"Exactly 1.0 (SAM is safe by construction!)\"],\n", ")" ] }, { "cell_type": "markdown", "id": "2d15dcd66bd044a6", "metadata": { "id": "2d15dcd66bd044a6" }, "source": [ "### Learn the model with SAM" ] }, { "cell_type": "code", "execution_count": null, "id": "bc70cacfcd7e46b6", "metadata": { "id": "bc70cacfcd7e46b6" }, "outputs": [], "source": [ "from amlgym.algorithms import get_algorithm\n", "from amlgym.util.util import empty_domain\n", "\n", "sam = get_algorithm('sam')\n", "domain_path = get_domain_path('blocksworld')\n", "domain_empty_path = empty_domain(domain_path)\n", "# with open(domain_empty_path, 'r') as f:\n", "# pprint(f.read())\n", "\n", "traj_paths = get_trajectories_path('blocksworld')\n", "\n", "model_sam = sam.learn(domain_empty_path, traj_paths)\n", "\n", "domain_sam_path = 'sam_blocksworld.pddl'\n", "with open(domain_sam_path, 'w') as f:\n", " f.write(model_sam)\n", "print('Saved learned model to', domain_sam_path)" ] }, { "cell_type": "markdown", "id": "7829c2e5d1b648ba", "metadata": { "id": "7829c2e5d1b648ba" }, "source": [ "### Evaluate: syntactic precision and recall" ] }, { "cell_type": 
"code", "execution_count": null, "id": "52f78f08812249df", "metadata": { "id": "52f78f08812249df" }, "outputs": [], "source": [ "from amlgym.metrics import syntactic_precision, syntactic_recall\n", "\n", "domain_ref_path = get_domain_path('blocksworld')\n", "prec = syntactic_precision(domain_sam_path, domain_ref_path)\n", "rec = syntactic_recall (domain_sam_path, domain_ref_path)\n", "\n", "print(\"Recall:\");\n", "pprint(rec)" ] }, { "cell_type": "code", "execution_count": null, "id": "0ef8560e470a4644", "metadata": { "id": "0ef8560e470a4644" }, "outputs": [], "source": [ "reveal(\n", " poll_sam_precision,\n", " pick_bucket(prec[\"mean\"], poll_sam_precision.options),\n", " formatter=lambda b: f\"{b} (mean precision = {prec['mean']:.3f})\",\n", ")\n" ] }, { "cell_type": "markdown", "id": "eb8118184c67430e", "metadata": { "id": "eb8118184c67430e" }, "source": [ "### 🔍 Diff: learned model vs. reference\n", "\n", "Numbers can be misleading — let's *see* what SAM actually learned." ] }, { "cell_type": "code", "execution_count": null, "id": "3d53579ffd414629", "metadata": { "id": "3d53579ffd414629" }, "outputs": [], "source": [ "html_side_by_side(domain_sam_path, domain_ref_path, \"SAM (learned)\", \"Reference\")" ] }, { "cell_type": "markdown", "id": "0c7c84126ddf42e1", "metadata": { "id": "0c7c84126ddf42e1" }, "source": [ "### 🔮 Prediction: problem solving\n", "\n", "Now we test the **learned** model efficacy for solving new planning problems." 
] }, { "cell_type": "code", "execution_count": null, "id": "bbbace710cc4480a", "metadata": { "id": "bbbace710cc4480a" }, "outputs": [], "source": [ "poll_solve = make_poll(\n", " \"What ratio of test problems can the SAM-learned model solve?\",\n", " [\"Below 0.3\", \"0.3 – 0.7\", \"0.7 – 0.99\", \"1.0 (perfect)\"],\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "aa89e7abac984bd2", "metadata": { "id": "aa89e7abac984bd2" }, "outputs": [], "source": [ "from amlgym.metrics import problem_solving\n", "\n", "probs_paths = get_problems_path('blocksworld', kind='solving')\n", "metrics_sam = problem_solving(domain_sam_path, domain_ref_path, probs_paths, timeout=60)" ] }, { "cell_type": "code", "execution_count": null, "id": "1b6e9561b71c417d", "metadata": { "id": "1b6e9561b71c417d" }, "outputs": [], "source": [ "solve_ratio = metrics_sam.get(\"solving_ratio\", metrics_sam.get(\"solving\", 0.0))\n", "print(f\"Solving ratio: {solve_ratio:.2f}\")\n", "\n", "reveal(\n", " poll_solve,\n", " pick_bucket(solve_ratio, poll_solve.options),\n", " formatter=lambda b: f\"{b}\",\n", ")\n" ] }, { "cell_type": "markdown", "id": "df8d0bd45736470f", "metadata": { "id": "df8d0bd45736470f" }, "source": [ "### Predictive power\n", "\n", "This metric goes beyond problem solving: on a test set of states,\n", "*does the learned model predict the same applicability and effects as the\n", "reference one?*" ] }, { "cell_type": "markdown", "id": "27da9c0a", "metadata": { "id": "27da9c0a" }, "source": [ "### 🔮 Prediction: predicted-effects precision\n", "\n", "Predictive power asks a tougher question than solving: for every test state, does the learned model give the *same predicted effects* as the reference?" 
] }, { "cell_type": "code", "execution_count": null, "id": "87f0d9d1", "metadata": { "id": "87f0d9d1" }, "outputs": [], "source": [ "poll_predeff = make_poll(\n", " \"What do you think the mean precision on predicted effects will be for SAM on blocksworld?\",\n", " [\"Below 0.5 (often predicts wrong fluents)\",\n", " \"0.5 – 0.8 (mostly right, occasional spurious effects)\",\n", " \"0.8 – 0.99 (almost identical to the reference)\",\n", " \"Exactly 1.0 (every predicted effect is correct)\"],\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "7daf67af13d447af", "metadata": { "id": "7daf67af13d447af" }, "outputs": [], "source": [ "from amlgym.benchmarks import get_test_states\n", "from amlgym.modeling.UPEnv import UPEnv\n", "from amlgym.metrics import predictive_power\n", "\n", "all_test_states = get_test_states('blocksworld')\n", "problem_paths = get_problems_path('blocksworld', kind='predictive_power')\n", "problem_path = problem_paths[0]\n", "test_states = all_test_states[problem_path.split('/')[-1]]\n", "\n", "simulator_learned = UPEnv(domain_sam_path, problem_path)\n", "simulator_ref = UPEnv(domain_ref_path, problem_path)\n", "\n", "predictive_metrics = predictive_power(simulator_learned, simulator_ref, test_states)\n", "pprint(predictive_metrics['applicability'])" ] }, { "cell_type": "code", "execution_count": null, "id": "45d4ad28", "metadata": { "id": "45d4ad28" }, "outputs": [], "source": [ "pred_eff_prec = float(predictive_metrics[\"predicted_effects\"][\"mean_precision\"])\n", "reveal(\n", " poll_predeff,\n", " pick_bucket(pred_eff_prec, poll_predeff.options),\n", " formatter=lambda b: f\"{b} (mean precision on predicted effects = {pred_eff_prec:.3f})\",\n", ")" ] }, { "cell_type": "markdown", "id": "465d0561c0f04bec", "metadata": { "id": "465d0561c0f04bec" }, "source": [ "## 3. 
Play with the number of learning traces\n", "\n", "Move the slider below to vary how many traces\n", "SAM gets to see, and watch precision and recall react.\n" ] }, { "cell_type": "markdown", "id": "bdfbc9ef88654a70", "metadata": { "id": "bdfbc9ef88654a70" }, "source": [ "### How many trajectories do you actually need?\n", "\n", "Does SAM learn the domain after 1 or 3 trajectories? Does it really need all 10?\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b7c1f6eee0054b10", "metadata": { "id": "b7c1f6eee0054b10" }, "outputs": [], "source": [ "from ipywidgets import interact_manual, IntSlider\n", "\n", "EXPLORE_DOMAIN = \"barman\"\n", "_dom_path = get_domain_path(EXPLORE_DOMAIN)\n", "_dom_empty = empty_domain(_dom_path)\n", "_trajs = get_trajectories_path(EXPLORE_DOMAIN)\n", "\n", "@interact_manual(\n", " n=IntSlider(min=1, max=10, step=1, value=3,\n", " description=\"# trajectories\")\n", ")\n", "def explore_n_traj(n):\n", " learner = get_algorithm(\"sam\")\n", " model = learner.learn(_dom_empty, _trajs[:n])\n", " out = f\"sam_n{n}.pddl\"\n", " with open(out, \"w\") as f:\n", " f.write(model)\n", " p = syntactic_precision(out, _dom_path)\n", " r = syntactic_recall (out, _dom_path)\n", " p_overall = p[\"mean\"] if isinstance(p, dict) and \"mean\" in p else (\n", " sum(p.values()) / max(len(p), 1) if isinstance(p, dict) else float(p))\n", " r_overall = r[\"mean\"] if isinstance(r, dict) and \"mean\" in r else (\n", " sum(r.values()) / max(len(r), 1) if isinstance(r, dict) else float(r))\n", " bar_compare({\"precision\": p_overall, \"recall\": r_overall},\n", " title=f\"SAM on {EXPLORE_DOMAIN} with {n} trajectories\",\n", " ylabel=\"score\")\n" ] }, { "cell_type": "markdown", "id": "407f2129883d4cbb", "metadata": { "id": "407f2129883d4cbb" }, "source": [ "## 4. Question A: NOLAM vs. 
OffLAM on **tpp** (syntactic precision)\n", "\n", "Both learners see the same noiseless full trajectories from the `tpp` domain.\n", "**Who wins on syntactic precision?**\n", "\n", "NOLAM is built for noisy traces, but we set `noise=0`. OffLAM was designed for partial observability. Which of the two will be more influenced by its design assumptions?" ] }, { "cell_type": "code", "execution_count": null, "id": "ad1b5fd59a214e02", "metadata": { "id": "ad1b5fd59a214e02" }, "outputs": [], "source": [ "poll_qa = make_poll(\n", " \"Who do you bet on?\",\n", " [\"OffLAM wins\", \"NOLAM wins\", \"It's a tie\"],\n", ")" ] }, { "cell_type": "markdown", "id": "12d96e0637be48a2", "metadata": { "id": "12d96e0637be48a2" }, "source": [ "Now let's run both and see." ] }, { "cell_type": "code", "execution_count": null, "id": "47fb54fe4769482d", "metadata": { "id": "47fb54fe4769482d" }, "outputs": [], "source": [ "offlam = get_algorithm('offlam')\n", "nolam = get_algorithm('nolam', noise=0.0)\n", "\n", "domain_path = get_domain_path('tpp')\n", "domain_empty_path = empty_domain(domain_path)\n", "traj_paths = get_trajectories_path('tpp')\n", "\n", "domain_offlam = offlam.learn(domain_empty_path, traj_paths)\n", "domain_nolam = nolam.learn (domain_empty_path, traj_paths)\n", "\n", "domain_offlam_path = 'offlam_tpp.pddl'\n", "domain_nolam_path = 'nolam_tpp.pddl'\n", "with open(domain_offlam_path, 'w') as f: f.write(domain_offlam)\n", "with open(domain_nolam_path , 'w') as f: f.write(domain_nolam)\n", "print('Models saved.')" ] }, { "cell_type": "code", "execution_count": null, "id": "8dc0fb5a19cb4917", "metadata": { "id": "8dc0fb5a19cb4917" }, "outputs": [], "source": [ "p_off = syntactic_precision(domain_offlam_path, domain_path)['mean']\n", "p_nol = syntactic_precision(domain_nolam_path , domain_path)['mean']\n", "r_off = syntactic_recall(domain_offlam_path, domain_path)['mean']\n", "r_nol = syntactic_recall(domain_nolam_path , domain_path)['mean']\n", "\n", "bar_compare({'OffLAM': p_off, 'NOLAM': 
p_nol},\n", " title='Question A — syntactic precision on tpp', ylabel='precision')\n", "bar_compare({'OffLAM': r_off, 'NOLAM': r_nol},\n", " title='syntactic recall on tpp', ylabel='recall')\n", "\n", "if p_off > p_nol + 1e-6: winner_a = \"OffLAM wins\"\n", "elif p_nol > p_off + 1e-6: winner_a = \"NOLAM wins\"\n", "else: winner_a = \"It's a tie\"\n", "\n", "reveal(poll_qa, winner_a,\n", " formatter=lambda w: f\"{w} (OffLAM={p_off:.3f}, NOLAM={p_nol:.3f})\")" ] }, { "cell_type": "markdown", "id": "9855ec66abf94c00", "metadata": { "id": "9855ec66abf94c00" }, "source": [ "#### 🔍 Where did they disagree? Diff each learned model against the reference." ] }, { "cell_type": "code", "execution_count": null, "id": "502cbc4f67124521", "metadata": { "id": "502cbc4f67124521" }, "outputs": [], "source": [ "print('--- OffLAM vs reference ---')\n", "html_side_by_side(domain_offlam_path, domain_path, 'OffLAM', 'Reference')" ] }, { "cell_type": "code", "execution_count": null, "id": "ba2175d71a0c4d4a", "metadata": { "id": "ba2175d71a0c4d4a" }, "outputs": [], "source": [ "print('--- NOLAM vs reference ---')\n", "html_side_by_side(domain_nolam_path, domain_path, 'NOLAM', 'Reference')" ] }, { "cell_type": "markdown", "id": "05b058a6994a4fe4", "metadata": { "id": "05b058a6994a4fe4" }, "source": [ "## 5. Question B: SAM vs. OffLAM on **goldminer** (solving ratio)\n", "\n", "Same trajectories, different question. 
**Who produces a model that solves more\n", "planning problems?**" ] }, { "cell_type": "code", "execution_count": null, "id": "e6f8315cd2dd4321", "metadata": { "id": "e6f8315cd2dd4321" }, "outputs": [], "source": [ "poll_qb = make_poll(\n", " \"Who do you bet on?\",\n", " [\"SAM wins\", \"OffLAM wins\", \"It's a tie\"],\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "2f9a69af2cd4449b", "metadata": { "id": "2f9a69af2cd4449b" }, "outputs": [], "source": [ "sam_b = get_algorithm('sam')\n", "offlam_b = get_algorithm('offlam')\n", "\n", "domain_path_b = get_domain_path('goldminer')\n", "domain_empty_b = empty_domain(domain_path_b)\n", "traj_paths_b = get_trajectories_path('goldminer')\n", "\n", "model_sam_b = sam_b .learn(domain_empty_b, traj_paths_b)\n", "model_offlam_b = offlam_b.learn(domain_empty_b, traj_paths_b)\n", "\n", "sam_path_b = 'sam_goldminer.pddl'\n", "offlam_path_b = 'offlam_goldminer.pddl'\n", "with open(sam_path_b , 'w') as f: f.write(model_sam_b)\n", "with open(offlam_path_b, 'w') as f: f.write(model_offlam_b)" ] }, { "cell_type": "code", "execution_count": null, "id": "8faaacb6579541df", "metadata": { "id": "8faaacb6579541df" }, "outputs": [], "source": [ "probs_b = get_problems_path('goldminer', kind='solving')\n", "\n", "m_sam = problem_solving(sam_path_b , domain_path_b, probs_b, timeout=60, show_progress=False)\n", "m_offlam = problem_solving(offlam_path_b, domain_path_b, probs_b, timeout=60, show_progress=False)\n", "\n", "s_sam = m_sam.get('solving_ratio', m_sam .get('solving', 0.0))\n", "s_offlam = m_offlam.get('solving_ratio', m_offlam.get('solving', 0.0))\n", "\n", "bar_compare({'SAM': s_sam, 'OffLAM': s_offlam},\n", " title='Question B — solving ratio on goldminer', ylabel='Solving ratio')\n", "\n", "if s_sam > s_offlam: winner_b = \"SAM wins\"\n", "elif s_offlam > s_sam: winner_b = \"OffLAM wins\"\n", "else: winner_b = \"It's a tie\"\n", "\n", "reveal(poll_qb, winner_b,\n", " formatter=lambda w: f\"{w} (SAM={s_sam:.3f}, 
OffLAM={s_offlam:.3f})\")" ] }, { "cell_type": "code", "source": [ "html_side_by_side(offlam_path_b, domain_path_b, 'OffLAM', 'Reference')" ], "metadata": { "id": "EbkM1lyCpyxi" }, "id": "EbkM1lyCpyxi", "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "html_side_by_side(sam_path_b, domain_path_b, 'SAM', 'Reference')" ], "metadata": { "id": "NgJYHsyeqJ4X" }, "id": "NgJYHsyeqJ4X", "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "id": "ea70416b59d34899", "metadata": { "id": "ea70416b59d34899" }, "source": [ "## 6. Final challenge: the leaderboard 🏆\n", "\n", "Pick a **domain**, pick an **algorithm**, and\n", "submit your run. Your **solving ratio** on the test problems is your score.\n", "\n", "Submissions accumulate in the leaderboard below. Whoever's at the top when the\n", "session ends wins.\n", "\n", "*Tip:* try a domain you haven't seen yet — `parking`, ..." ] }, { "cell_type": "code", "execution_count": null, "id": "784c9d7d45374ea9", "metadata": { "id": "784c9d7d45374ea9" }, "outputs": [], "source": [ "from amlgym.benchmarks import get_domain_names\n", "\n", "ALL_DOMAINS = sorted(get_domain_names())\n", "ALL_ALGORITHMS = [\"sam\", \"offlam\", \"nolam\", \"rosame\"]\n", "\n", "leaderboard = [] # accumulating list of dicts\n", "\n", "def _score(algo_name, domain_name):\n", " learner = get_algorithm(algo_name)\n", " dom_path = get_domain_path(domain_name)\n", " dom_empty = empty_domain(dom_path)\n", " trajs = get_trajectories_path(domain_name)\n", " model = learner.learn(dom_empty, trajs)\n", " out_path = f\"leaderboard_{algo_name}_{domain_name}.pddl\"\n", " with open(out_path, \"w\") as f:\n", " f.write(model)\n", "\n", " print(f\"Computing problem solving ratio of {algo_name} on {domain_name} ...\")\n", " probs = get_problems_path(domain_name, kind=\"solving\")\n", " m = problem_solving(out_path, dom_path, probs,\n", " timeout=3, show_progress=False)\n", " return m.get(\"solving_ratio\", m.get(\"solving\", 0.0)), 
m\n", "\n", "name_w = widgets.Text(value=\"\", placeholder=\"Your name / team\",\n", " description=\"Name:\")\n", "dom_w = widgets.Dropdown(options=ALL_DOMAINS, value=ALL_DOMAINS[0],\n", " description=\"Domain:\")\n", "algo_w = widgets.Dropdown(options=ALL_ALGORITHMS, value=\"sam\",\n", " description=\"Algorithm:\")\n", "go_btn = widgets.Button(description=\"🚀 Submit run\",\n", " button_style=\"success\")\n", "out_w = widgets.Output()\n", "\n", "def _on_click(_):\n", " out_w.clear_output()\n", " with out_w:\n", " nm = name_w.value.strip() or \"anonymous\"\n", " algo, dom = algo_w.value, dom_w.value\n", " print(f\"Running {algo} on {dom} ...\")\n", " try:\n", " s, _ = _score(algo, dom)\n", " except Exception as e:\n", " print(\"Run failed:\", e)\n", " return\n", " leaderboard.append({\n", " \"name\": nm, \"algorithm\": algo, \"domain\": dom,\n", " \"solving_ratio\": round(s, 3),\n", " })\n", " df = (pd.DataFrame(leaderboard)\n", " .sort_values(\"solving_ratio\", ascending=False)\n", " .reset_index(drop=True))\n", " df.index = df.index + 1\n", " display(Markdown(\"### 🏆 Leaderboard\"))\n", " display(df)\n", "\n", "go_btn.on_click(_on_click)\n", "display(widgets.VBox([name_w, dom_w, algo_w, go_btn, out_w]))\n" ] }, { "cell_type": "markdown", "id": "b1b44eca5ff649e7", "metadata": { "id": "b1b44eca5ff649e7" }, "source": [ "### Closing thoughts\n", "\n", "We have now seen the full cycle: pick a domain and associated training set of trajectories, pick a learning algorithm, learn a domain, evaluate the domain by measuring syntactic, solving, and predictive power metrics. Two ideas worth taking home:\n", "\n", "- **The right algorithm depends on the data you actually have.** SAM is powerful when traces are fully observable; OffLAM when they are partially observable; NOLAM when state observations are noisy.\n", "- **Syntactic, solving and predictive power metrics are complementary.** For example: a model can be syntactically imperfect but still solve every test problem. 
Evaluate against the task you care about.\n", "\n", "Thanks for playing!\n", "\n", "
\n", "\n", "
\n", " 📝 Tell us what you think:\n", " \n", " anonymous questionnaire\n", "
\n", " \n", " \"Feedback\n", "
" ] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10" } }, "nbformat": 4, "nbformat_minor": 5 }