checked that fit works with timesteps > 1
[notebooks.git] / rnn.ipynb
blob6acddab3b93a6443efb5918157fbd8b9f8dfcccd
2   "nbformat": 4,
3   "nbformat_minor": 0,
4   "metadata": {
5     "kernelspec": {
6       "display_name": "Python 3",
7       "name": "python3"
8     },
9     "language_info": {
10       "name": "python"
11     },
12     "colab": {
13       "name": "rnn.ipynb",
14       "provenance": []
15     }
16   },
17   "cells": [
18     {
19       "cell_type": "markdown",
20       "metadata": {
21         "id": "am_B9iKqXsGX"
22       },
23       "source": [
24         "The following additional libraries are needed to run this\n",
25         "notebook. Note that running on Colab is experimental, please report a Github\n",
26         "issue if you have any problem."
27       ]
28     },
29     {
30       "cell_type": "code",
31       "metadata": {
32         "id": "joeZTMPsXsGb",
33         "outputId": "29116f84-4616-454e-dd7b-7d5476bd34cd",
34         "colab": {
35           "base_uri": "https://localhost:8080/"
36         }
37       },
38       "source": [
39         "!pip install d2l==0.17.0\n",
40         "!pip install -U mxnet-cu101==1.7.0\n"
41       ],
42       "execution_count": 1,
43       "outputs": [
44         {
45           "output_type": "stream",
46           "name": "stdout",
47           "text": [
48             "Collecting d2l==0.17.0\n",
49             "  Downloading d2l-0.17.0-py3-none-any.whl (83 kB)\n",
50             "\u001b[?25l\r\u001b[K     |████                            | 10 kB 12.2 MB/s eta 0:00:01\r\u001b[K     |███████▉                        | 20 kB 9.4 MB/s eta 0:00:01\r\u001b[K     |███████████▉                    | 30 kB 6.6 MB/s eta 0:00:01\r\u001b[K     |███████████████▊                | 40 kB 6.4 MB/s eta 0:00:01\r\u001b[K     |███████████████████▊            | 51 kB 3.7 MB/s eta 0:00:01\r\u001b[K     |███████████████████████▋        | 61 kB 3.8 MB/s eta 0:00:01\r\u001b[K     |███████████████████████████▋    | 71 kB 3.7 MB/s eta 0:00:01\r\u001b[K     |███████████████████████████████▌| 81 kB 4.0 MB/s eta 0:00:01\r\u001b[K     |████████████████████████████████| 83 kB 664 kB/s \n",
51             "\u001b[?25hRequirement already satisfied: jupyter in /usr/local/lib/python3.7/dist-packages (from d2l==0.17.0) (1.0.0)\n",
52             "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from d2l==0.17.0) (1.1.5)\n",
53             "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from d2l==0.17.0) (2.23.0)\n",
54             "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from d2l==0.17.0) (1.19.5)\n",
55             "Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from d2l==0.17.0) (3.2.2)\n",
56             "Requirement already satisfied: notebook in /usr/local/lib/python3.7/dist-packages (from jupyter->d2l==0.17.0) (5.3.1)\n",
57             "Requirement already satisfied: qtconsole in /usr/local/lib/python3.7/dist-packages (from jupyter->d2l==0.17.0) (5.1.1)\n",
58             "Requirement already satisfied: ipywidgets in /usr/local/lib/python3.7/dist-packages (from jupyter->d2l==0.17.0) (7.6.5)\n",
59             "Requirement already satisfied: jupyter-console in /usr/local/lib/python3.7/dist-packages (from jupyter->d2l==0.17.0) (5.2.0)\n",
60             "Requirement already satisfied: nbconvert in /usr/local/lib/python3.7/dist-packages (from jupyter->d2l==0.17.0) (5.6.1)\n",
61             "Requirement already satisfied: ipykernel in /usr/local/lib/python3.7/dist-packages (from jupyter->d2l==0.17.0) (4.10.1)\n",
62             "Requirement already satisfied: jupyter-client in /usr/local/lib/python3.7/dist-packages (from ipykernel->jupyter->d2l==0.17.0) (5.3.5)\n",
63             "Requirement already satisfied: tornado>=4.0 in /usr/local/lib/python3.7/dist-packages (from ipykernel->jupyter->d2l==0.17.0) (5.1.1)\n",
64             "Requirement already satisfied: traitlets>=4.1.0 in /usr/local/lib/python3.7/dist-packages (from ipykernel->jupyter->d2l==0.17.0) (5.1.0)\n",
65             "Requirement already satisfied: ipython>=4.0.0 in /usr/local/lib/python3.7/dist-packages (from ipykernel->jupyter->d2l==0.17.0) (5.5.0)\n",
66             "Requirement already satisfied: decorator in /usr/local/lib/python3.7/dist-packages (from ipython>=4.0.0->ipykernel->jupyter->d2l==0.17.0) (4.4.2)\n",
67             "Requirement already satisfied: prompt-toolkit<2.0.0,>=1.0.4 in /usr/local/lib/python3.7/dist-packages (from ipython>=4.0.0->ipykernel->jupyter->d2l==0.17.0) (1.0.18)\n",
68             "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/dist-packages (from ipython>=4.0.0->ipykernel->jupyter->d2l==0.17.0) (57.4.0)\n",
69             "Requirement already satisfied: pexpect in /usr/local/lib/python3.7/dist-packages (from ipython>=4.0.0->ipykernel->jupyter->d2l==0.17.0) (4.8.0)\n",
70             "Requirement already satisfied: pickleshare in /usr/local/lib/python3.7/dist-packages (from ipython>=4.0.0->ipykernel->jupyter->d2l==0.17.0) (0.7.5)\n",
71             "Requirement already satisfied: simplegeneric>0.8 in /usr/local/lib/python3.7/dist-packages (from ipython>=4.0.0->ipykernel->jupyter->d2l==0.17.0) (0.8.1)\n",
72             "Requirement already satisfied: pygments in /usr/local/lib/python3.7/dist-packages (from ipython>=4.0.0->ipykernel->jupyter->d2l==0.17.0) (2.6.1)\n",
73             "Requirement already satisfied: wcwidth in /usr/local/lib/python3.7/dist-packages (from prompt-toolkit<2.0.0,>=1.0.4->ipython>=4.0.0->ipykernel->jupyter->d2l==0.17.0) (0.2.5)\n",
74             "Requirement already satisfied: six>=1.9.0 in /usr/local/lib/python3.7/dist-packages (from prompt-toolkit<2.0.0,>=1.0.4->ipython>=4.0.0->ipykernel->jupyter->d2l==0.17.0) (1.15.0)\n",
75             "Requirement already satisfied: nbformat>=4.2.0 in /usr/local/lib/python3.7/dist-packages (from ipywidgets->jupyter->d2l==0.17.0) (5.1.3)\n",
76             "Requirement already satisfied: ipython-genutils~=0.2.0 in /usr/local/lib/python3.7/dist-packages (from ipywidgets->jupyter->d2l==0.17.0) (0.2.0)\n",
77             "Requirement already satisfied: widgetsnbextension~=3.5.0 in /usr/local/lib/python3.7/dist-packages (from ipywidgets->jupyter->d2l==0.17.0) (3.5.1)\n",
78             "Requirement already satisfied: jupyterlab-widgets>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from ipywidgets->jupyter->d2l==0.17.0) (1.0.2)\n",
79             "Requirement already satisfied: jsonschema!=2.5.0,>=2.4 in /usr/local/lib/python3.7/dist-packages (from nbformat>=4.2.0->ipywidgets->jupyter->d2l==0.17.0) (2.6.0)\n",
80             "Requirement already satisfied: jupyter-core in /usr/local/lib/python3.7/dist-packages (from nbformat>=4.2.0->ipywidgets->jupyter->d2l==0.17.0) (4.8.1)\n",
81             "Requirement already satisfied: jinja2 in /usr/local/lib/python3.7/dist-packages (from notebook->jupyter->d2l==0.17.0) (2.11.3)\n",
82             "Requirement already satisfied: terminado>=0.8.1 in /usr/local/lib/python3.7/dist-packages (from notebook->jupyter->d2l==0.17.0) (0.12.1)\n",
83             "Requirement already satisfied: Send2Trash in /usr/local/lib/python3.7/dist-packages (from notebook->jupyter->d2l==0.17.0) (1.8.0)\n",
84             "Requirement already satisfied: pyzmq>=13 in /usr/local/lib/python3.7/dist-packages (from jupyter-client->ipykernel->jupyter->d2l==0.17.0) (22.3.0)\n",
85             "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from jupyter-client->ipykernel->jupyter->d2l==0.17.0) (2.8.2)\n",
86             "Requirement already satisfied: ptyprocess in /usr/local/lib/python3.7/dist-packages (from terminado>=0.8.1->notebook->jupyter->d2l==0.17.0) (0.7.0)\n",
87             "Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.7/dist-packages (from jinja2->notebook->jupyter->d2l==0.17.0) (2.0.1)\n",
88             "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->d2l==0.17.0) (0.10.0)\n",
89             "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->d2l==0.17.0) (1.3.2)\n",
90             "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->d2l==0.17.0) (2.4.7)\n",
91             "Requirement already satisfied: bleach in /usr/local/lib/python3.7/dist-packages (from nbconvert->jupyter->d2l==0.17.0) (4.1.0)\n",
92             "Requirement already satisfied: entrypoints>=0.2.2 in /usr/local/lib/python3.7/dist-packages (from nbconvert->jupyter->d2l==0.17.0) (0.3)\n",
93             "Requirement already satisfied: defusedxml in /usr/local/lib/python3.7/dist-packages (from nbconvert->jupyter->d2l==0.17.0) (0.7.1)\n",
94             "Requirement already satisfied: mistune<2,>=0.8.1 in /usr/local/lib/python3.7/dist-packages (from nbconvert->jupyter->d2l==0.17.0) (0.8.4)\n",
95             "Requirement already satisfied: testpath in /usr/local/lib/python3.7/dist-packages (from nbconvert->jupyter->d2l==0.17.0) (0.5.0)\n",
96             "Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.7/dist-packages (from nbconvert->jupyter->d2l==0.17.0) (1.5.0)\n",
97             "Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from bleach->nbconvert->jupyter->d2l==0.17.0) (21.0)\n",
98             "Requirement already satisfied: webencodings in /usr/local/lib/python3.7/dist-packages (from bleach->nbconvert->jupyter->d2l==0.17.0) (0.5.1)\n",
99             "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->d2l==0.17.0) (2018.9)\n",
100             "Requirement already satisfied: qtpy in /usr/local/lib/python3.7/dist-packages (from qtconsole->jupyter->d2l==0.17.0) (1.11.2)\n",
101             "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->d2l==0.17.0) (1.24.3)\n",
102             "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->d2l==0.17.0) (2021.5.30)\n",
103             "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->d2l==0.17.0) (3.0.4)\n",
104             "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->d2l==0.17.0) (2.10)\n",
105             "Installing collected packages: d2l\n",
106             "Successfully installed d2l-0.17.0\n",
107             "Collecting mxnet-cu101==1.7.0\n",
108             "  Downloading mxnet_cu101-1.7.0-py2.py3-none-manylinux2014_x86_64.whl (846.0 MB)\n",
109             "\u001b[K     |███████████████████████████████▌| 834.1 MB 1.1 MB/s eta 0:00:11tcmalloc: large alloc 1147494400 bytes == 0x559d9b964000 @  0x7f01279d4615 0x559d632ef4cc 0x559d633cf47a 0x559d632f22ed 0x559d633e3e1d 0x559d63365e99 0x559d633609ee 0x559d632f3bda 0x559d63365d00 0x559d633609ee 0x559d632f3bda 0x559d63362737 0x559d633e4c66 0x559d63361daf 0x559d633e4c66 0x559d63361daf 0x559d633e4c66 0x559d63361daf 0x559d632f4039 0x559d63337409 0x559d632f2c52 0x559d63365c25 0x559d633609ee 0x559d632f3bda 0x559d63362737 0x559d633609ee 0x559d632f3bda 0x559d63361915 0x559d632f3afa 0x559d63361c0d 0x559d633609ee\n",
110             "\u001b[K     |████████████████████████████████| 846.0 MB 20 kB/s \n",
111             "\u001b[?25hCollecting graphviz<0.9.0,>=0.8.1\n",
112             "  Downloading graphviz-0.8.4-py2.py3-none-any.whl (16 kB)\n",
113             "Requirement already satisfied: requests<3,>=2.20.0 in /usr/local/lib/python3.7/dist-packages (from mxnet-cu101==1.7.0) (2.23.0)\n",
114             "Requirement already satisfied: numpy<2.0.0,>1.16.0 in /usr/local/lib/python3.7/dist-packages (from mxnet-cu101==1.7.0) (1.19.5)\n",
115             "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.20.0->mxnet-cu101==1.7.0) (2.10)\n",
116             "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.20.0->mxnet-cu101==1.7.0) (2021.5.30)\n",
117             "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.20.0->mxnet-cu101==1.7.0) (3.0.4)\n",
118             "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.20.0->mxnet-cu101==1.7.0) (1.24.3)\n",
119             "Installing collected packages: graphviz, mxnet-cu101\n",
120             "  Attempting uninstall: graphviz\n",
121             "    Found existing installation: graphviz 0.10.1\n",
122             "    Uninstalling graphviz-0.10.1:\n",
123             "      Successfully uninstalled graphviz-0.10.1\n",
124             "Successfully installed graphviz-0.8.4 mxnet-cu101-1.7.0\n"
125           ]
126         }
127       ]
128     },
129     {
130       "cell_type": "markdown",
131       "metadata": {
132         "origin_pos": 0,
133         "id": "KEZ2fE5vXsGc"
134       },
135       "source": [
136         "# Recurrent Neural Networks\n",
137         ":label:`sec_rnn`\n",
138         "\n",
139         "\n",
140         "In :numref:`sec_language_model` we introduced $n$-gram models, where the conditional probability of word $x_t$ at time step $t$ only depends on the $n-1$ previous words.\n",
141         "If we want to incorporate the possible effect of words earlier than time step $t-(n-1)$ on $x_t$,\n",
142         "we need to increase $n$.\n",
143         "However, the number of model parameters would also increase exponentially with it, as we need to store $|\\mathcal{V}|^n$ numbers for a vocabulary set $\\mathcal{V}$.\n",
144         "Hence, rather than modeling $P(x_t \\mid x_{t-1}, \\ldots, x_{t-n+1})$ it is preferable to use a latent variable model:\n",
145         "\n",
146         "$$P(x_t \\mid x_{t-1}, \\ldots, x_1) \\approx P(x_t \\mid h_{t-1}),$$\n",
147         "\n",
148         "where $h_{t-1}$ is a *hidden state* (also known as a hidden variable) that stores the sequence information up to time step $t-1$.\n",
149         "In general,\n",
150         "the hidden state at any time step $t$ could be computed based on both the current input $x_{t}$ and the previous hidden state $h_{t-1}$:\n",
151         "\n",
152         "$$h_t = f(x_{t}, h_{t-1}).$$\n",
153         ":eqlabel:`eq_ht_xt`\n",
154         "\n",
155         "For a sufficiently powerful function $f$ in :eqref:`eq_ht_xt`, the latent variable model is not an approximation. After all, $h_t$ may simply store all the data it has observed so far.\n",
156         "However, it could potentially make both computation and storage expensive.\n",
157         "\n",
158         "Recall that we have discussed hidden layers with hidden units in :numref:`chap_perceptrons`.\n",
159         "It is noteworthy that\n",
160         "hidden layers and hidden states refer to two very different concepts.\n",
161         "Hidden layers are, as explained, layers that are hidden from view on the path from input to output.\n",
162         "Hidden states are technically speaking *inputs* to whatever we do at a given step,\n",
163         "and they can only be computed by looking at data at previous time steps.\n",
164         "\n",
165         "*Recurrent neural networks* (RNNs) are neural networks with hidden states. Before introducing the RNN model, we first revisit the MLP model introduced in :numref:`sec_mlp`.\n",
166         "\n",
167         "## Neural Networks without Hidden States\n",
168         "\n",
169         "Let us take a look at an MLP with a single hidden layer.\n",
170         "Let the hidden layer's activation function be $\\phi$.\n",
171         "Given a minibatch of examples $\\mathbf{X} \\in \\mathbb{R}^{n \\times d}$ with batch size $n$ and $d$ inputs, the hidden layer's output $\\mathbf{H} \\in \\mathbb{R}^{n \\times h}$ is calculated as\n",
172         "\n",
173         "$$\\mathbf{H} = \\phi(\\mathbf{X} \\mathbf{W}_{xh} + \\mathbf{b}_h).$$\n",
174         ":eqlabel:`rnn_h_without_state`\n",
175         "\n",
176         "In :eqref:`rnn_h_without_state`, we have the weight parameter $\\mathbf{W}_{xh} \\in \\mathbb{R}^{d \\times h}$, the bias parameter $\\mathbf{b}_h \\in \\mathbb{R}^{1 \\times h}$, and the number of hidden units $h$, for the hidden layer.\n",
177         "Thus, broadcasting (see :numref:`subsec_broadcasting`) is applied during the summation.\n",
178         "Next, the hidden variable $\\mathbf{H}$ is used as the input of the output layer. The output layer is given by\n",
179         "\n",
180         "$$\\mathbf{O} = \\mathbf{H} \\mathbf{W}_{hq} + \\mathbf{b}_q,$$\n",
181         "\n",
182         "where $\\mathbf{O} \\in \\mathbb{R}^{n \\times q}$ is the output variable, $\\mathbf{W}_{hq} \\in \\mathbb{R}^{h \\times q}$ is the weight parameter, and $\\mathbf{b}_q \\in \\mathbb{R}^{1 \\times q}$ is the bias parameter of the output layer.  If it is a classification problem, we can use $\\text{softmax}(\\mathbf{O})$ to compute the probability distribution of the output categories.\n",
183         "\n",
184         "This is entirely analogous to the regression problem we solved previously in :numref:`sec_sequence`, hence we omit details.\n",
185         "Suffice it to say that we can pick feature-label pairs at random and learn the parameters of our network via automatic differentiation and stochastic gradient descent.\n",
186         "\n",
187         "## Recurrent Neural Networks with Hidden States\n",
188         ":label:`subsec_rnn_w_hidden_states`\n",
189         "\n",
190         "Matters are entirely different when we have hidden states. Let us look at the structure in some more detail.\n",
191         "\n",
192         "Assume that we have\n",
193         "a minibatch of inputs\n",
194         "$\\mathbf{X}_t \\in \\mathbb{R}^{n \\times d}$\n",
195         "at time step $t$.\n",
196         "In other words,\n",
197         "for a minibatch of $n$ sequence examples,\n",
198         "each row of $\\mathbf{X}_t$ corresponds to one example at time step $t$ from the sequence.\n",
199         "Next,\n",
200         "denote by $\\mathbf{H}_t  \\in \\mathbb{R}^{n \\times h}$ the hidden variable of time step $t$.\n",
201         "Unlike the MLP, here we save the hidden variable $\\mathbf{H}_{t-1}$ from the previous time step and introduce a new weight parameter $\\mathbf{W}_{hh} \\in \\mathbb{R}^{h \\times h}$ to describe how to use the hidden variable of the previous time step in the current time step. Specifically, the calculation of the hidden variable of the current time step is determined by the input of the current time step together with the hidden variable of the previous time step:\n",
202         "\n",
203         "$$\\mathbf{H}_t = \\phi(\\mathbf{X}_t \\mathbf{W}_{xh} + \\mathbf{H}_{t-1} \\mathbf{W}_{hh}  + \\mathbf{b}_h).$$\n",
204         ":eqlabel:`rnn_h_with_state`\n",
205         "\n",
206         "Compared with :eqref:`rnn_h_without_state`, :eqref:`rnn_h_with_state` adds one more term $\\mathbf{H}_{t-1} \\mathbf{W}_{hh}$ and thus\n",
207         "instantiates :eqref:`eq_ht_xt`.\n",
208         "From the relationship between hidden variables $\\mathbf{H}_t$ and $\\mathbf{H}_{t-1}$ of adjacent time steps,\n",
209         "we know that these variables captured and retained the sequence's historical information up to their current time step, just like the state or memory of the neural network's current time step. Therefore, such a hidden variable is called a *hidden state*.\n",
210         "Since the hidden state uses the same definition of the previous time step in the current time step, the computation of :eqref:`rnn_h_with_state` is *recurrent*. Hence, neural networks with hidden states\n",
211         "based on recurrent computation are named\n",
212         "*recurrent neural networks*.\n",
213         "Layers that perform\n",
214         "the computation of :eqref:`rnn_h_with_state`\n",
215         "in RNNs\n",
216         "are called *recurrent layers*.\n",
217         "\n",
218         "\n",
219         "There are many different ways for constructing RNNs.\n",
220         "RNNs with a hidden state defined by :eqref:`rnn_h_with_state` are very common.\n",
221         "For time step $t$,\n",
222         "the output of the output layer is similar to the computation in the MLP:\n",
223         "\n",
224         "$$\\mathbf{O}_t = \\mathbf{H}_t \\mathbf{W}_{hq} + \\mathbf{b}_q.$$\n",
225         "\n",
226         "Parameters of the RNN\n",
227         "include the weights $\\mathbf{W}_{xh} \\in \\mathbb{R}^{d \\times h}, \\mathbf{W}_{hh} \\in \\mathbb{R}^{h \\times h}$,\n",
228         "and the bias $\\mathbf{b}_h \\in \\mathbb{R}^{1 \\times h}$\n",
229         "of the hidden layer,\n",
230         "together with the weights $\\mathbf{W}_{hq} \\in \\mathbb{R}^{h \\times q}$\n",
231         "and the bias $\\mathbf{b}_q \\in \\mathbb{R}^{1 \\times q}$\n",
232         "of the output layer.\n",
233         "It is worth mentioning that\n",
234         "even at different time steps,\n",
235         "RNNs always use these model parameters.\n",
236         "Therefore, the parameterization cost of an RNN\n",
237         "does not grow as the number of time steps increases.\n",
238         "\n",
239         ":numref:`fig_rnn` illustrates the computational logic of an RNN at three adjacent time steps.\n",
240         "At any time step $t$,\n",
241         "the computation of the hidden state can be treated as:\n",
242         "(i) concatenating the input $\\mathbf{X}_t$ at the current time step $t$ and the hidden state $\\mathbf{H}_{t-1}$ at the previous time step $t-1$;\n",
243         "(ii) feeding the concatenation result into a fully-connected layer with the activation function $\\phi$.\n",
244         "The output of such a fully-connected layer is the hidden state $\\mathbf{H}_t$ of the current time step $t$.\n",
245         "In this case,\n",
246         "the model parameters are the concatenation of $\\mathbf{W}_{xh}$ and $\\mathbf{W}_{hh}$, and a bias of $\\mathbf{b}_h$, all from :eqref:`rnn_h_with_state`.\n",
247         "The hidden state of the current time step $t$, $\\mathbf{H}_t$, will participate in computing the hidden state $\\mathbf{H}_{t+1}$ of the next time step $t+1$.\n",
248         "What is more, $\\mathbf{H}_t$ will also be\n",
249         "fed into the fully-connected output layer\n",
250         "to compute the output\n",
251         "$\\mathbf{O}_t$ of the current time step $t$.\n",
252         "\n",
253         "![An RNN with a hidden state.](http://d2l.ai/_images/rnn.svg)\n",
254         ":label:`fig_rnn`\n",
255         "\n",
256         "We just mentioned that the calculation of $\\mathbf{X}_t \\mathbf{W}_{xh} + \\mathbf{H}_{t-1} \\mathbf{W}_{hh}$ for the hidden state is equivalent to\n",
257         "matrix multiplication of\n",
258         "concatenation of $\\mathbf{X}_t$ and $\\mathbf{H}_{t-1}$\n",
259         "and\n",
260         "concatenation of $\\mathbf{W}_{xh}$ and $\\mathbf{W}_{hh}$.\n",
261         "Though this can be proven in mathematics,\n",
262         "in the following we just use a simple code snippet to show this.\n",
263         "To begin with,\n",
264         "we define matrices `X`, `W_xh`, `H`, and `W_hh`, whose shapes are (3, 1), (1, 4), (3, 4), and (4, 4), respectively.\n",
265         "Multiplying `X` by `W_xh`, and `H` by `W_hh`, respectively, and then adding these two multiplications,\n",
266         "we obtain a matrix of shape (3, 4).\n"
267       ]
268     },
269     {
270       "cell_type": "code",
271       "metadata": {
272         "origin_pos": 1,
273         "tab": [
274           "mxnet"
275         ],
276         "id": "rw0ShhAhXsGe"
277       },
278       "source": [
279         "from mxnet import np, npx\n",
280         "from d2l import mxnet as d2l\n",
281         "\n",
282         "npx.set_np()"
283       ],
284       "execution_count": null,
285       "outputs": []
286     },
287     {
288       "cell_type": "code",
289       "metadata": {
290         "origin_pos": 4,
291         "tab": [
292           "mxnet"
293         ],
294         "id": "FCWF9xdqXsGe",
295         "outputId": "b2cbc992-be6b-41c5-dd5c-e35f72eb2a25"
296       },
297       "source": [
298         "X, W_xh = np.random.normal(0, 1, (3, 1)), np.random.normal(0, 1, (1, 4))\n",
299         "H, W_hh = np.random.normal(0, 1, (3, 4)), np.random.normal(0, 1, (4, 4))\n",
300         "np.dot(X, W_xh) + np.dot(H, W_hh)"
301       ],
302       "execution_count": null,
303       "outputs": [
304         {
305           "data": {
306             "text/plain": [
307               "array([[-0.21952868,  4.256434  ,  4.5812645 , -5.344988  ],\n",
308               "       [ 3.4478583 , -3.0177274 , -1.6777471 ,  7.535347  ],\n",
309               "       [ 2.239007  ,  1.4199957 ,  4.744728  , -8.421293  ]])"
310             ]
311           },
312           "execution_count": 2,
313           "metadata": {},
314           "output_type": "execute_result"
315         }
316       ]
317     },
318     {
319       "cell_type": "markdown",
320       "metadata": {
321         "origin_pos": 6,
322         "id": "9n2kju1SXsGf"
323       },
324       "source": [
325         "Now we concatenate the matrices `X` and `H`\n",
326         "along columns (axis 1),\n",
327         "and the matrices\n",
328         "`W_xh` and `W_hh` along rows (axis 0).\n",
329         "These two concatenations\n",
330         "result in\n",
331         "matrices of shape (3, 5)\n",
332         "and of shape (5, 4), respectively.\n",
333         "Multiplying these two concatenated matrices,\n",
334         "we obtain the same output matrix of shape (3, 4)\n",
335         "as above.\n"
336       ]
337     },
338     {
339       "cell_type": "code",
340       "metadata": {
341         "origin_pos": 7,
342         "tab": [
343           "mxnet"
344         ],
345         "id": "veGhxnrkXsGf",
346         "outputId": "ab21230c-7c03-49de-dbe1-d5b13457b002"
347       },
348       "source": [
349         "np.dot(np.concatenate((X, H), 1), np.concatenate((W_xh, W_hh), 0))"
350       ],
351       "execution_count": null,
352       "outputs": [
353         {
354           "data": {
355             "text/plain": [
356               "array([[-0.2195287,  4.256434 ,  4.5812645, -5.344988 ],\n",
357               "       [ 3.4478583, -3.0177271, -1.677747 ,  7.535347 ],\n",
358               "       [ 2.2390068,  1.4199957,  4.744728 , -8.421294 ]])"
359             ]
360           },
361           "execution_count": 3,
362           "metadata": {},
363           "output_type": "execute_result"
364         }
365       ]
366     },
367     {
368       "cell_type": "markdown",
369       "metadata": {
370         "origin_pos": 8,
371         "id": "IJeB_1DXXsGg"
372       },
373       "source": [
374         "## RNN-based Character-Level Language Models\n",
375         "\n",
376         "Recall that for language modeling in :numref:`sec_language_model`,\n",
377         "we aim to predict the next token based on\n",
378         "the current and past tokens,\n",
379         "thus we shift the original sequence by one token\n",
380         "as the labels.\n",
381         "Bengio et al. first proposed\n",
382         "to use a neural network for language modeling :cite:`Bengio.Ducharme.Vincent.ea.2003`.\n",
383         "In the following we illustrate how RNNs can be used to build a language model.\n",
384         "Let the minibatch size be one, and the sequence of the text be \"machine\".\n",
385         "To simplify training in subsequent sections,\n",
386         "we tokenize text into characters rather than words\n",
387         "and consider a *character-level language model*.\n",
388         ":numref:`fig_rnn_train` demonstrates how to predict the next character based on the current and previous characters via an RNN for character-level language modeling.\n",
389         "\n",
390         "![A character-level language model based on the RNN. The input and label sequences are \"machin\" and \"achine\", respectively.](https://github.com/d2l-ai/d2l-en-colab/blob/master/img/rnn-train.svg?raw=1)\n",
391         ":label:`fig_rnn_train`\n",
392         "\n",
393         "During the training process,\n",
394         "we run a softmax operation on the output from the output layer for each time step, and then use the cross-entropy loss to compute the error between the model output and the label.\n",
395         "Due to the recurrent computation of the hidden state in the hidden layer, the output of time step 3 in :numref:`fig_rnn_train`,\n",
396         "$\\mathbf{O}_3$, is determined by the text sequence \"m\", \"a\", and \"c\". Since the next character of the sequence in the training data is \"h\", the loss of time step 3 will depend on the probability distribution of the next character generated based on the feature sequence \"m\", \"a\", \"c\" and the label \"h\" of this time step.\n",
397         "\n",
398         "In practice, each token is represented by a $d$-dimensional vector, and we use a batch size $n>1$. Therefore, the input $\\mathbf X_t$ at time step $t$ will be a $n\\times d$ matrix, which is identical to what we discussed in :numref:`subsec_rnn_w_hidden_states`.\n",
399         "\n",
400         "\n",
401         "## Perplexity\n",
402         ":label:`subsec_perplexity`\n",
403         "\n",
404         "Last, let us discuss about how to measure the language model quality, which will be used to evaluate our RNN-based models in the subsequent sections.\n",
405         "One way is to check how surprising the text is.\n",
406         "A good language model is able to predict with\n",
407         "high-accuracy tokens that what we will see next.\n",
408         "Consider the following continuations of the phrase \"It is raining\", as proposed by different language models:\n",
409         "\n",
410         "1. \"It is raining outside\"\n",
411         "1. \"It is raining banana tree\"\n",
412         "1. \"It is raining piouw;kcj pwepoiut\"\n",
413         "\n",
414         "In terms of quality, example 1 is clearly the best. The words are sensible and logically coherent.\n",
415         "While it might not quite accurately reflect which word follows semantically (\"in San Francisco\" and \"in winter\" would have been perfectly reasonable extensions), the model is able to capture which kind of word follows.\n",
416         "Example 2 is considerably worse by producing a nonsensical extension. Nonetheless, at least the model has learned how to spell words and some degree of correlation between words. Last, example 3 indicates a poorly trained model that does not fit data properly.\n",
417         "\n",
418         "We might measure the quality of the model by computing  the likelihood of the sequence.\n",
419         "Unfortunately this is a number that is hard to understand and difficult to compare.\n",
420         "After all, shorter sequences are much more likely to occur than the longer ones,\n",
421         "hence evaluating the model on Tolstoy's magnum opus\n",
422         "*War and Peace* will inevitably produce a much smaller likelihood than, say, on Saint-Exupery's novella *The Little Prince*. What is missing is the equivalent of an average.\n",
423         "\n",
424         "Information theory comes handy here.\n",
425         "We have defined entropy, surprisal, and cross-entropy\n",
426         "when we introduced the softmax regression\n",
427         "(:numref:`subsec_info_theory_basics`)\n",
428         "and more of information theory is discussed in the [online appendix on information theory](https://d2l.ai/chapter_appendix-mathematics-for-deep-learning/information-theory.html).\n",
429         "If we want to compress text, we can ask about\n",
430         "predicting the next token given the current set of tokens.\n",
431         "A better language model should allow us to predict the next token more accurately.\n",
432         "Thus, it should allow us to spend fewer bits in compressing the sequence.\n",
433         "So we can measure it by the cross-entropy loss averaged\n",
434         "over all the $n$ tokens of a sequence:\n",
435         "\n",
436         "$$\\frac{1}{n} \\sum_{t=1}^n -\\log P(x_t \\mid x_{t-1}, \\ldots, x_1),$$\n",
437         ":eqlabel:`eq_avg_ce_for_lm`\n",
438         "\n",
439         "where $P$ is given by a language model and $x_t$ is the actual token observed at time step $t$ from the sequence.\n",
440         "This makes the performance on documents of different lengths comparable. For historical reasons, scientists in natural language processing prefer to use a quantity called *perplexity*. In a nutshell, it is the exponential of :eqref:`eq_avg_ce_for_lm`:\n",
441         "\n",
442         "$$\\exp\\left(-\\frac{1}{n} \\sum_{t=1}^n \\log P(x_t \\mid x_{t-1}, \\ldots, x_1)\\right).$$\n",
443         "\n",
444         "Perplexity can be best understood as the harmonic mean of the number of real choices that we have when deciding which token to pick next. Let us look at a number of cases:\n",
445         "\n",
446         "* In the best case scenario, the model always perfectly estimates the probability of the label token as 1. In this case the perplexity of the model is 1.\n",
447         "* In the worst case scenario, the model always predicts the probability of the label token as 0. In this situation, the perplexity is positive infinity.\n",
448         "* At the baseline, the model predicts a uniform distribution over all the available tokens of the vocabulary. In this case, the perplexity equals the number of unique tokens of the vocabulary. In fact, if we were to store the sequence without any compression, this would be the best we could do to encode it. Hence, this provides a nontrivial upper bound that any useful model must beat.\n",
449         "\n",
450         "In the following sections, we will implement RNNs\n",
451         "for character-level language models and use perplexity\n",
452         "to evaluate such models.\n",
453         "\n",
454         "\n",
455         "## Summary\n",
456         "\n",
457         "* A neural network that uses recurrent computation for hidden states is called a recurrent neural network (RNN).\n",
458         "* The hidden state of an RNN can capture historical information of the sequence up to the current time step.\n",
459         "* The number of RNN model parameters does not grow as the number of time steps increases.\n",
460         "* We can create character-level language models using an  RNN.\n",
461         "* We can use perplexity to evaluate the quality of language models.\n",
462         "\n",
463         "## Exercises\n",
464         "\n",
465         "1. If we use an RNN to predict the next character in a text sequence, what is the required dimension for any output?\n",
466         "1. Why can RNNs express the conditional probability of a token at some time step based on all the previous tokens in the text sequence?\n",
467         "1. What happens to the gradient if you backpropagate through a long sequence?\n",
468         "1. What are some of the problems associated with the language model described in this section?\n"
469       ]
470     },
471     {
472       "cell_type": "markdown",
473       "metadata": {
474         "origin_pos": 9,
475         "tab": [
476           "mxnet"
477         ],
478         "id": "Q6DkPSonXsGg"
479       },
480       "source": [
481         "[Discussions](https://discuss.d2l.ai/t/337)\n"
482       ]
483     }
484   ]