Create Batch Reset Hyperparameter tutorial notebook
[notebooks.git] / fmda / data / grib_to_geotiff.py
blob990ba367f65c7b31a04dbcb86da6432260544fa5
1 import argparse
2 from osgeo import gdal
3 import numpy as np
# Select GDAL's non-exception error mode for the Python bindings: failing
# calls (e.g. gdal.Open on an unreadable file, GetRasterBand with a bad
# index) return None instead of raising, so every call site must check
# return values. This mode is process-wide, so it also applies to any other
# code using GDAL in the same interpreter.
# NOTE(review): recent GDAL releases warn when neither UseExceptions() nor
# DontUseExceptions() is called explicitly; that may be why this is here.
gdal.DontUseExceptions()
def grib_to_geotiff(input_filename, output_filename_base, band_number):
    """
    Convert one band of an existing GRIB file to a single-band GeoTIFF.

    The band's metadata is printed to stdout, then the band is written to
    ``f"{output_filename_base}.{band_number}.tif"``. The output carries the
    source's geotransform, its projection and, when one is set, its NoData
    value.

    :param input_filename: str, path to the input GRIB file
    :param output_filename_base: str, base of the output file name;
        ".<band_number>.tif" is appended to it
    :param band_number: int, 1-based index of the band to extract
    :return: None
    :raises ValueError: if band_number is < 1 or larger than the number of
        bands in the input file
    :raises RuntimeError: if GDAL cannot open the input, read the band, or
        create the output file
    """
    # GDAL band indices are 1-based; reject bad values before touching the file.
    if band_number < 1:
        raise ValueError(f"band_number must be >= 1, got {band_number}")

    # Open the GRIB file. With GDAL exceptions disabled, failure yields None
    # rather than an exception, so check explicitly.
    ds = gdal.Open(input_filename)
    if ds is None:
        raise RuntimeError(f"GDAL could not open {input_filename}")

    # GetRasterBand would silently return None for an out-of-range index.
    if band_number > ds.RasterCount:
        raise ValueError(
            f"band_number {band_number} is out of range: "
            f"{input_filename} has {ds.RasterCount} band(s)"
        )

    band = ds.GetRasterBand(band_number)
    metadata = band.GetMetadata()

    print(f"Metadata for band {band_number}:")
    for key, value in metadata.items():
        print(f"{key}: {value}")

    # Read the pixel data; only 2D arrays are written out.
    arr = band.ReadAsArray()
    if arr is None:
        raise RuntimeError(f"Could not read band {band_number} from {input_filename}")
    if arr.ndim != 2:
        print(f"Skipping band {band_number} because it is not 2D")
        return

    output_filename = f"{output_filename_base}.{band_number}.tif"

    # Create a single-band GeoTIFF on disk with the source band's size and type.
    driver = gdal.GetDriverByName("GTiff")
    out_ds = driver.Create(output_filename, band.XSize, band.YSize, 1, band.DataType)
    if out_ds is None:
        raise RuntimeError(f"GDAL could not create {output_filename}")

    # Carry over georeferencing so the GeoTIFF lines up with the source grid.
    out_ds.SetGeoTransform(ds.GetGeoTransform())
    out_ds.SetProjection(ds.GetProjection())

    out_band = out_ds.GetRasterBand(1)
    out_band.WriteArray(arr)

    # Preserve the missing-value marker; without it, fill values would be
    # treated as real data by downstream readers.
    nodata = band.GetNoDataValue()
    if nodata is not None:
        out_band.SetNoDataValue(nodata)
    out_ds.FlushCache()

    # Drop band handles before their datasets; GDAL closes a dataset when the
    # last Python reference to it is released.
    out_band = None
    band = None
    out_ds = None
    ds = None

    print(f"Band {band_number} from {input_filename} saved to {output_filename}")
if __name__ == "__main__":
    # Command-line entry point: three positional arguments whose names match
    # grib_to_geotiff()'s parameters, so the parsed namespace is passed through
    # directly as keyword arguments.
    cli = argparse.ArgumentParser(
        description='Extract a band from a GRIB file and save as a GeoTIFF if it is 2D.'
    )
    cli.add_argument(
        'input_filename',
        help='Path to the input GRIB file',
    )
    cli.add_argument(
        'output_filename_base',
        help='Base of the output GeoTIFF filename (band number will be appended)',
    )
    cli.add_argument(
        'band_number',
        type=int,
        help='Number of the band to extract (1-based)',
    )

    options = cli.parse_args()
    grib_to_geotiff(**vars(options))