# sgn.git: R/Nirs/nirs.R
# Runs the `waves` package with data from the NIRS Breedbase tool to train phenotypic prediction models
# Author: Jenna Hershberger
6 # Load packages
7 library(dplyr)
8 library(magrittr)
9 library(devtools)
10 library(jsonlite)
11 library(waves)
12 library(stringr)
13 library(readr)
# Error handling
# TODO: Where should genotype and/or environment overlap be checked for the CV schemes: here or in the controller?
#### Read in and assign arguments ####
args <- commandArgs(trailingOnly = TRUE)

# Command-line booleans arrive as strings. Only the exact string "TRUE" counts
# as TRUE; "FALSE", any other text, or a missing argument (NA) gives FALSE
# instead of propagating NA into later `if` conditions.
as_flag <- function(x) {
  identical(x, "TRUE")
}

# args[1] = phenotype of interest (string with ontology name)
pheno <- args[1]

# args[2] = test preprocessing methods boolean
preprocessing <- as_flag(args[2])
# When preprocessing is FALSE, pretreatment 1 tells the model-saving step to
# use the raw data; otherwise NULL is passed. ifelse() cannot return NULL (it
# errors with "replacement has length zero"), so a scalar if/else is used.
preprocessing.method <- if (preprocessing) NULL else 1

# args[3] = number of sampling iterations
num.iterations <- as.numeric(args[3])

# args[4] = model algorithm
model.method <- args[4]

# args[5] = tune length
tune.length <- as.numeric(args[5])

# args[6] = Random Forest variable importance boolean
rf.var.importance <- as_flag(args[6])
# args[7] = CV method as string.
# "random" and "stratified" both select the non-specialized cross-validation
# (cv.scheme = NULL), with random or stratified sampling respectively.
# Any other value must name a specialized scheme (CV1, CV2, CV0 or CV00) and
# is passed through unchanged.
cv.scheme <- args[7]
stratified.sampling <- TRUE
if (identical(cv.scheme, "random")) {
  stratified.sampling <- FALSE
  cv.scheme <- NULL
} else if (identical(cv.scheme, "stratified")) {
  cv.scheme <- NULL
} else if (!(cv.scheme %in% c("CV1", "CV2", "CV0", "CV00"))) {
  # Fail early with a clear message instead of erroring on NA or deep inside waves.
  stop("Unrecognized CV method '", cv.scheme, "'; expected one of: ",
       "random, stratified, CV1, CV2, CV0, CV00.", call. = FALSE)
}
# Parse one JSON export from the NIRS Breedbase tool into the column layout
# used by the rest of this script:
#   unique.id, reference, germplasm* column(s), then spectral columns X<wavelength>.
# NOTE(review): assumes exactly one "trait.*" column holds the phenotype and
# that spectra are flattened as "nirs_spectra.X<wavelength>"; confirm against
# the controller that produces the JSON.
read_nirs_json <- function(json.txt) {
  jsonlite::fromJSON(txt = json.txt, flatten = TRUE) %>%
    rename(unique.id = observationUnitId) %>%
    rename_with(.fn = ~ rep("reference", length(.x)),
                .cols = starts_with("trait.")) %>%
    # fixed(): the "." is a literal separator, not a regex wildcard
    rename_with(.fn = ~ str_remove(.x, fixed("nirs_spectra.")),
                .cols = starts_with("nirs_spectra")) %>%
    dplyr::select(unique.id, reference, starts_with("germplasm"),
                  num_range(prefix = "X", range = 1:100000)) %>%
    mutate(reference = as.numeric(reference))
}

# args[8] = training data.frame: observationUnit level data with phenotypes and spectra in JSON format
train.input <- read_nirs_json(args[8])
# Preview at most the first 5 rows and 5 columns; safe for smaller inputs.
print(utils::head(train.input[, seq_len(min(5L, ncol(train.input))), drop = FALSE], 5L))

# args[9] = test data.frame in the same JSON format, or the string "NULL" when
# no test set is supplied (a missing argument is treated the same way)
if (!is.na(args[9]) && args[9] != "NULL") {
  print("TEST DATA PROVIDED")
  test.input <- read_nirs_json(args[9])
} else {
  print("NO TEST DATA")
  test.input <- NULL
}
# args[10] = save model boolean. When "TRUE", the best model is fitted with
# save_model() and written to args[11]. Otherwise model performance is only
# evaluated with TestModelPerformance().
save.model <- identical(args[10], "TRUE")

# Numeric wavelengths of the spectral columns (named X<wavelength>), taken from
# the column names so identifier/genotype columns can never leak in, whichever
# data frame (train.ready or train.input) is inspected.
spectral_wavelengths <- function(df) {
  parse_number(grep("^X[0-9]+$", colnames(df), value = TRUE))
}

if (save.model) {
  if (is.null(cv.scheme)) {
    # Non-specialized CV: drop the genotype column so only unique.id,
    # reference and spectra remain.
    # NOTE(review): save_model() is given no test set on this path, so
    # test.input is not used here.
    train.ready <- train.input %>% dplyr::select(-germplasmName)
    print(utils::head(train.ready[, seq_len(min(5L, ncol(train.ready))), drop = FALSE], 5L))
    print(train.ready$reference)
    # Test model using non-specialized cv scheme
    sm.output <- save_model(df = train.ready,
                            write.model = FALSE,
                            pretreatment = preprocessing.method,
                            model.save.folder = NULL,
                            model.name = "PredictionModel",
                            best.model.metric = "RMSE",
                            tune.length = tune.length,
                            model.method = model.method,
                            num.iterations = num.iterations,
                            stratified.sampling = stratified.sampling,
                            cv.scheme = NULL,
                            trial1 = NULL,
                            trial2 = NULL,
                            trial3 = NULL)
  } else {
    # Test model using specialized cv scheme; the trials keep their germplasm
    # column. write.model = FALSE (as on the other path) because the model is
    # saved explicitly with saveRDS() below; with model.save.folder = NULL,
    # waves would otherwise also write a copy into the working directory.
    sm.output <- save_model(df = NULL,
                            write.model = FALSE,
                            pretreatment = preprocessing.method,
                            model.save.folder = NULL,
                            model.name = "PredictionModel",
                            best.model.metric = "RMSE",
                            tune.length = tune.length,
                            model.method = model.method,
                            num.iterations = num.iterations,
                            stratified.sampling = stratified.sampling,
                            cv.scheme = cv.scheme,
                            trial1 = train.input,
                            trial2 = test.input,
                            trial3 = NULL)
  }
  results.df <- sm.output$best.model.stats
  # args[11] = model save location with .Rds in filename
  saveRDS(sm.output$best.model, file = args[11])
} else { # DON'T SAVE MODEL
  if (is.null(cv.scheme)) {
    train.ready <- train.input %>% dplyr::select(-germplasmName)
    test.ready <- if (is.null(test.input)) NULL else dplyr::select(test.input, -germplasmName)
    # Test model using non-specialized cv scheme
    results.df <- TestModelPerformance(train.data = train.ready,
                                       num.iterations = num.iterations,
                                       test.data = test.ready,
                                       preprocessing = preprocessing,
                                       wavelengths = spectral_wavelengths(train.ready),
                                       tune.length = tune.length,
                                       model.method = model.method,
                                       output.summary = TRUE,
                                       rf.variable.importance = rf.var.importance,
                                       stratified.sampling = stratified.sampling,
                                       cv.scheme = NULL,
                                       trial1 = NULL, trial2 = NULL, trial3 = NULL)
  } else {
    # Test model using specialized cv scheme. train.ready does not exist on
    # this path, so wavelengths come from train.input.
    results.df <- TestModelPerformance(train.data = NULL,
                                       num.iterations = num.iterations,
                                       test.data = NULL,
                                       preprocessing = preprocessing,
                                       wavelengths = spectral_wavelengths(train.input),
                                       tune.length = tune.length,
                                       model.method = model.method,
                                       output.summary = TRUE,
                                       rf.variable.importance = rf.var.importance,
                                       stratified.sampling = FALSE,
                                       cv.scheme = cv.scheme,
                                       trial1 = train.input, trial2 = test.input,
                                       trial3 = NULL)
  }
  if (rf.var.importance) {
    # With variable importance requested, the result is unpacked as a list:
    # element 1 is the performance summary, element 2 the variable importance.
    var.imp <- results.df[[2]]
    results.df <- results.df[[1]]
    # args[12] = variable importance results output file name
    write.csv(var.imp, file = args[12], row.names = FALSE)
    # args[13] = variable importance figure output file name
    # TODO: no figure is produced yet. The previous
    # ggsave(filename = args[13], plot = results.plot, device = "png") call
    # referenced an undefined `results.plot` (and ggplot2 is not loaded). It
    # aborted the script before the results table (args[14]) was written.
  }
}
# args[14] = table output file name: the results table is written as CSV
# without row names.
utils::write.csv(results.df, args[14], row.names = FALSE)

# args[15] = figure output file name
# Not produced yet; the intended call is kept for reference:
# ggsave(filename = args[15], plot = results.plot, device = "png")