Read pickle files with util
[notebooks.git] / fmda / test-plk2train.ipynb
blobc0409b588cb223e138c658d61983509d78715d60
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "540bb001-bad1-4969-933f-5c3100060732",
6    "metadata": {},
7    "source": [
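8     "# v1 training on pkl data (April 2024 in Boise)\n",
9     "\n",
10     "Builds the training cases from the pickle files in `data/` with `pkl2train`, caches them in `train.pkl`, and runs `run_rnn_pkl` on the first cases for each selected parameter set."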
9    ]
10   },
11   {
12    "cell_type": "code",
13    "execution_count": null,
14    "id": "83cc1dc4-3dcb-4325-9263-58101a3dc378",
15    "metadata": {
16     "tags": []
17    },
18    "outputs": [],
19    "source": [
20     "from utils import print_dict_summary, print_first, str2time, logging_setup\n",
21     "import pickle\n",
22     "import logging\n",
23     "import os.path as osp\n",
24     "from moisture_rnn_pkl import pkl2train, run_rnn_pkl\n",
25     "from moisture_rnn import create_rnn_data_2 \n",
26     "from utils import hash2"
27    ]
28   },
29   {
30    "cell_type": "code",
31    "execution_count": null,
32    "id": "17db9b90-a931-4674-a447-5b8ffbcdc86a",
33    "metadata": {},
34    "outputs": [],
35    "source": [
36     "# Set up logging with the project helper imported from utils\n",
37     "logging_setup()"
37    ]
38   },
39   {
40    "cell_type": "code",
41    "execution_count": null,
42    "id": "eabdbd9c-07d9-4bae-9851-cca79f321895",
43    "metadata": {},
44    "outputs": [],
45    "source": [
46     "# Input pickle files, all located under file_dir\n",
47     "file_dir = 'data'\n",
48     "file_names = [\n",
49     "    'reproducibility_dict2.pickle',\n",
50     "    'test_NW_202401.pkl',\n",
51     "    'test_CA_202401.pkl',\n",
52     "]\n",
53     "# osp.join inserts the OS-specific separator (\\ on Windows, / elsewhere)\n",
54     "file_paths = [osp.join(file_dir, name) for name in file_names]"
49    ]
50   },
51   {
52    "cell_type": "code",
53    "execution_count": null,
54    "id": "dcca6185-e799-4dd1-8acb-87ad33c411d7",
55    "metadata": {},
56    "outputs": [],
57    "source": [
58     "# read/write control (flags are independent of each other)\n",
59     "train_file='train.pkl'   # cache file for the training cases\n",
60     "train_create=True   # build the training cases from file_paths with pkl2train\n",
61     "train_write=True    # pickle the newly created cases to train_file (needs train_create)\n",
62     "train_read=True     # load the training cases back from train_file"
63    ]
64   },
65   {
66    "cell_type": "code",
67    "execution_count": null,
68    "id": "62e31f16-d887-4552-bc2b-6024992c0a0b",
69    "metadata": {},
70    "outputs": [],
71    "source": [
72     "# NOTE(review): `train` is only defined by the pkl2train/pickle cell below,\n",
73     "# so this call must stay commented out here (or move below that cell) for Run All.\n",
74     "# print_dict_summary(train)"
73    ]
74   },
75   {
76    "cell_type": "code",
77    "execution_count": null,
78    "id": "bc0a775b-b587-42ef-8576-e36dc0be3a75",
79    "metadata": {},
80    "outputs": [],
81    "source": [
82     "# Build, cache and/or reload the training cases according to the flags above.\n",
83     "if train_create:\n",
84     "    logging.info('creating the training cases from files %s',file_paths)\n",
85     "    # file_paths were built with osp.join, so they use the OS path separator\n",
86     "    train = pkl2train(file_paths)\n",
87     "if train_write:\n",
88     "    if train_create:\n",
89     "        with open(train_file, 'wb') as file:\n",
90     "            logging.info('Writing the train cases into file %s',train_file)\n",
91     "            pickle.dump(train, file)\n",
92     "    else:\n",
93     "        # only persist cases created in this run, never stale kernel state\n",
94     "        logging.warning('train_write ignored: train_create is False, not writing %s',train_file)\n",
95     "if train_read:\n",
96     "    logging.info('Reading the train cases from file %s',train_file)\n",
97     "    # pickle.load can execute arbitrary code: only read trusted files\n",
98     "    with open(train_file,'rb') as file:\n",
99     "        train=pickle.load(file)"
94    ]
95   },
96   {
97    "cell_type": "code",
98    "execution_count": null,
99    "id": "698df86b-8550-4135-81df-45dbf503dd4e",
100    "metadata": {},
101    "outputs": [],
102    "source": [
103     "from module_param_sets import param_sets"
104    ]
105   },
106   {
107    "cell_type": "code",
108    "execution_count": null,
109    "id": "4b0c9a9b-dd02-4251-aa4a-2acc1101e153",
110    "metadata": {},
111    "outputs": [],
112    "source": [
113     "param_sets_keys=['0']\n",
114     "n_cases=10   # number of leading training cases to run (use 1 for a quick smoke test)\n",
115     "cases=list(train.keys())[:n_cases]\n",
116     "cases"
117    ]
118   },
119   {
120    "cell_type": "code",
121    "execution_count": null,
122    "id": "dd22baf2-59d2-460e-8c47-b20116dd5982",
123    "metadata": {},
124    "outputs": [],
125    "source": [
126     "# Record the run configuration in the log\n",
127     "logging.info('Running over parameter sets %s', param_sets_keys)\n",
128     "logging.info('Running over cases %s', cases)"
128    ]
129   },
130   {
131    "cell_type": "code",
132    "execution_count": null,
133    "id": "dc5b47bd-4fbc-44b8-b2dd-d118e068b450",
134    "metadata": {},
135    "outputs": [],
136    "source": [
137     "# Run the RNN on every selected case for every selected parameter set.\n",
138     "for i in param_sets_keys:\n",
139     "    # NOTE(review): alias of the dict imported from module_param_sets; mutated in place below\n",
140     "    params = param_sets[i]\n",
141     "    for case in cases:\n",
142     "        # alias: the fix-ups below update train[case] itself\n",
143     "        case_data = train[case]\n",
144     "        logging.info('Processing case %s',case)\n",
145     "        print_dict_summary(case_data)\n",
146     "        logging.info('Misc fixes, change later')\n",
147     "        params['initialize']=False\n",
148     "        n_hours = case_data['X'].shape[0]   # number of rows in X\n",
149     "        case_data['hours'] = n_hours\n",
150     "        case_data['h2'] = n_hours   # not doing prediction yet\n",
151     "        case_data['Y'] = case_data['Y'].reshape(-1, 1)   # make Y a single column\n",
152     "        print_dict_summary(case_data)\n",
153     "        run_rnn_pkl(case_data, params)"
149    ]
150   },
151   {
152    "cell_type": "code",
153    "execution_count": null,
154    "id": "15384e4d-b8ec-4700-bdc2-83b0433d11c9",
155    "metadata": {},
156    "outputs": [],
157    "source": [
158     "# Completion marker in the log for a full Run All\n",
159     "logging.info('test-plk2train.ipynb done')"
159    ]
160   },
161   {
162    "cell_type": "code",
163    "execution_count": null,
164    "id": "ad5dae6c-1269-4674-a49e-2efe8b956911",
165    "metadata": {},
166    "outputs": [],
167    "source": []
168   }
169  ],
170  "metadata": {
171   "kernelspec": {
172    "display_name": "Python 3 (ipykernel)",
173    "language": "python",
174    "name": "python3"
175   },
176   "language_info": {
177    "codemirror_mode": {
178     "name": "ipython",
179     "version": 3
180    },
181    "file_extension": ".py",
182    "mimetype": "text/x-python",
183    "name": "python",
184    "nbconvert_exporter": "python",
185    "pygments_lexer": "ipython3",
186    "version": "3.10.9"
187   }
188  },
189  "nbformat": 4,
190  "nbformat_minor": 5