{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part 3: Single-View Geometry\n",
    "\n",
    "## Usage\n",
    "This code snippet provides the overall code structure and some interactive plotting interfaces for the *Single-View Geometry* section of Assignment 3. In the [main function](#Main-function), we outline the required functionality step by step. Some of the functions that involve interactive plots are already provided, but [the rest](#Your-implementation) are left for you to implement.\n",
    "\n",
    "## Package installation\n",
    "- This code uses the `tkinter` package. Installation instructions can be found [here](https://anaconda.org/anaconda/tk)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Common imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib tk\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Provided functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_input_lines(im, min_lines=3):\n",
    "    \"\"\"\n",
    "    Allows user to input line segments; computes centers and directions.\n",
    "    Inputs:\n",
    "        im: np.ndarray of shape (height, width, 3)\n",
    "        min_lines: minimum number of lines required\n",
    "    Returns:\n",
    "        n: number of lines from input\n",
    "        lines: np.ndarray of shape (3, n)\n",
    "            where each column denotes the parameters of the line equation\n",
    "        centers: np.ndarray of shape (3, n)\n",
    "            where each column denotes the homogeneous coordinates of the centers\n",
    "    \"\"\"\n",
    "    n = 0\n",
    "    lines = np.zeros((3, 0))\n",
    "    centers = np.zeros((3, 0))\n",
    "\n",
    "    plt.figure()\n",
    "    plt.imshow(im)\n",
    "    plt.show()\n",
    "    print('Set at least %d lines to compute vanishing point' % min_lines)\n",
    "    while True:\n",
    "        print('Click the two endpoints, use the right key to undo, and use the middle key to stop input')\n",
    "        clicked = plt.ginput(2, timeout=0, show_clicks=True)\n",
    "        if not clicked or len(clicked) < 2:\n",
    "            if n < min_lines:\n",
    "                print('Need at least %d lines, you have %d now' % (min_lines, n))\n",
    "                continue\n",
    "            else:\n",
    "                # Stop getting lines if number of lines is enough\n",
    "                break\n",
    "\n",
    "        # Unpack user inputs and save as homogeneous coordinates\n",
    "        pt1 = np.array([clicked[0][0], clicked[0][1], 1])\n",
    "        pt2 = np.array([clicked[1][0], clicked[1][1], 1])\n",
    "        # Get line equation using cross product\n",
    "        # Line equation: line[0] * x + line[1] * y + line[2] = 0\n",
    "        line = np.cross(pt1, pt2)\n",
    "        lines = np.append(lines, line.reshape((3, 1)), axis=1)\n",
    "        # Get center coordinate of the line segment\n",
    "        center = (pt1 + pt2) / 2\n",
    "        centers = np.append(centers, center.reshape((3, 1)), axis=1)\n",
    "\n",
    "        # Plot line segment\n",
    "        plt.plot([pt1[0], pt2[0]], [pt1[1], pt2[1]], color='b')\n",
    "\n",
    "        n += 1\n",
    "\n",
    "    return n, lines, centers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_lines_and_vp(im, lines, vp):\n",
    "    \"\"\"\n",
    "    Plots user-input lines and the calculated vanishing point.\n",
    "    Inputs:\n",
    "        im: np.ndarray of shape (height, width, 3)\n",
    "        lines: np.ndarray of shape (3, n)\n",
    "            where each column denotes the parameters of the line equation\n",
    "        vp: np.ndarray of shape (3, )\n",
    "    \"\"\"\n",
    "    bx1 = min(1, vp[0] / vp[2]) - 10\n",
    "    bx2 = max(im.shape[1], vp[0] / vp[2]) + 10\n",
    "    by1 = min(1, vp[1] / vp[2]) - 10\n",
    "    by2 = max(im.shape[0], vp[1] / vp[2]) + 10\n",
    "\n",
    "    plt.figure()\n",
    "    plt.imshow(im)\n",
    "    for i in range(lines.shape[1]):\n",
    "        if lines[0, i] < lines[1, i]:\n",
    "            pt1 = np.cross(np.array([1, 0, -bx1]), lines[:, i])\n",
    "            pt2 = np.cross(np.array([1, 0, -bx2]), lines[:, i])\n",
    "        else:\n",
    "            pt1 = np.cross(np.array([0, 1, -by1]), lines[:, i])\n",
    "            pt2 = np.cross(np.array([0, 1, -by2]), lines[:, i])\n",
    "        pt1 = pt1 / pt1[2]\n",
    "        pt2 = pt2 / pt2[2]\n",
    "        plt.plot([pt1[0], pt2[0]], [pt1[1], pt2[1]], 'g')\n",
    "\n",
    "    plt.plot(vp[0] / vp[2], vp[1] / vp[2], 'ro')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_top_and_bottom_coordinates(im, obj):\n",
    "    \"\"\"\n",
    "    For a specific object, prompts user to record the top coordinate and the bottom coordinate in the image.\n",
    "    Inputs:\n",
    "        im: np.ndarray of shape (height, width, 3)\n",
    "        obj: string, object name\n",
    "    Returns:\n",
    "        coord: np.ndarray of shape (3, 2)\n",
    "            where coord[:, 0] is the homogeneous coordinate of the top of the object and coord[:, 1] is the homogeneous\n",
    "            coordinate of the bottom\n",
    "    \"\"\"\n",
    "    plt.figure()\n",
    "    plt.imshow(im)\n",
    "\n",
    "    print('Click on the top coordinate of %s' % obj)\n",
    "    clicked = plt.ginput(1, timeout=0, show_clicks=True)\n",
    "    x1, y1 = clicked[0]\n",
    "    # Uncomment this line to enable a vertical line to help align the two coordinates\n",
    "    # plt.plot([x1, x1], [0, im.shape[0]], 'b')\n",
    "    print('Click on the bottom coordinate of %s' % obj)\n",
    "    clicked = plt.ginput(1, timeout=0, show_clicks=True)\n",
    "    x2, y2 = clicked[0]\n",
    "\n",
    "    plt.plot([x1, x2], [y1, y2], 'b')\n",
    "\n",
    "    return np.array([[x1, x2], [y1, y2], [1, 1]])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Your implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_vanishing_point():\n",
    "    \"\"\"\n",
    "    Solves for the vanishing point using the user-input lines.\n",
    "    \"\"\"\n",
    "    # <YOUR IMPLEMENTATION>\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_horizon_line():\n",
    "    \"\"\"\n",
    "    Calculates the ground horizon line.\n",
    "    \"\"\"\n",
    "    # <YOUR IMPLEMENTATION>\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_horizon_line():\n",
    "    \"\"\"\n",
    "    Plots the horizon line.\n",
    "    \"\"\"\n",
    "    # <YOUR IMPLEMENTATION>\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_camera_parameters():\n",
    "    \"\"\"\n",
    "    Computes the camera parameters. Hint: The SymPy package is suitable for this.\n",
    "    \"\"\"\n",
    "    # <YOUR IMPLEMENTATION>\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_rotation_matrix():\n",
    "    \"\"\"\n",
    "    Computes the rotation matrix using the camera parameters.\n",
    "    \"\"\"\n",
    "    # <YOUR IMPLEMENTATION>\n",
    "    return R"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def estimate_height():\n",
    "    \"\"\"\n",
    "    Estimates height for a specific object using the recorded coordinates. You might need to plot additional images here for\n",
    "    your report.\n",
    "    \"\"\"\n",
    "    # <YOUR IMPLEMENTATION>\n",
    "    pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Main function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the input image as an array; the relative path is resolved against the current working directory\n",
    "im = np.asarray(Image.open('CSL.jpeg'))\n",
    "\n",
    "# Part 1\n",
    "# Get vanishing points for each of the directions\n",
    "num_vpts = 3\n",
    "# Column i of vpts holds the i-th vanishing point in homogeneous coordinates\n",
    "vpts = np.zeros((3, num_vpts))\n",
    "for i in range(num_vpts):\n",
    "    print('Getting vanishing point %d' % i)\n",
    "    # Get at least three lines from user input\n",
    "    n, lines, centers = get_input_lines(im)\n",
    "    # <YOUR IMPLEMENTATION> Solve for vanishing point\n",
    "    # NOTE: the stub is called without arguments; once implemented it will need access\n",
    "    # to the collected `lines` (e.g. passed in as a parameter)\n",
    "    vpts[:, i] = get_vanishing_point()\n",
    "    # Plot the lines and the vanishing point\n",
    "    plot_lines_and_vp(im, lines, vpts[:, i])\n",
    "\n",
    "# <YOUR IMPLEMENTATION> Get the ground horizon line\n",
    "horizon_line = get_horizon_line()\n",
    "# <YOUR IMPLEMENTATION> Plot the ground horizon line\n",
    "plot_horizon_line()\n",
    "\n",
    "# Part 2\n",
    "# <YOUR IMPLEMENTATION> Solve for the camera parameters (f, u, v)\n",
    "# NOTE: this unpacking fails until get_camera_parameters returns three values\n",
    "f, u, v = get_camera_parameters()\n",
    "\n",
    "# Part 3\n",
    "# <YOUR IMPLEMENTATION> Solve for the rotation matrix\n",
    "R = get_rotation_matrix()\n",
    "\n",
    "# Part 4\n",
    "# Record image coordinates for each object and store in map\n",
    "objects = ('person', 'CSL building', 'the spike statue', 'the lamp posts')\n",
    "# coords maps object name -> (3, 2) array of homogeneous top/bottom coordinates\n",
    "coords = dict()\n",
    "for obj in objects:\n",
    "    coords[obj] = get_top_and_bottom_coordinates(im, obj)\n",
    "\n",
    "# <YOUR IMPLEMENTATION> Estimate heights\n",
    "# objects[1:] skips 'person' -- presumably the reference object of known height;\n",
    "# confirm against the assignment text\n",
    "# NOTE(review): `height` is overwritten on every iteration and never printed or stored\n",
    "for obj in objects[1:]:\n",
    "    print('Estimating height of %s' % obj)\n",
    "    height = estimate_height()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
