1 # Runs `waves` package with data from the NIRS Breedbase tool to train phenotypic prediction models
4 # Jenna Hershberger (jmh579@cornell.edu)
16 # TODO: Decide where to check for genotype and/or environment overlap between the CV sets: here or in the controller?
18 #### Read in and assign arguments ####
args <- commandArgs(trailingOnly = TRUE)

# Interpret a command-line flag: TRUE only for the exact string "TRUE".
# A missing argument (NA) or any other string yields FALSE -- never NA -- so
# the result can be used directly in if().
arg_is_true <- function(x) {
  isTRUE(x == "TRUE")
}

# args[1] = phenotype of interest (string with ontology name); not assigned here

# args[2] = test preprocessing methods boolean
preprocessing <- arg_is_true(args[2])
# 1 tells save_model() to use the raw data (no pretreatment) when preprocessing
# is FALSE; otherwise NULL is passed (presumably letting save_model() evaluate
# its pretreatments -- confirm against the waves docs).
# A plain if/else is required: ifelse() cannot return NULL and errors with
# "replacement has length zero" when its condition is TRUE.
preprocessing.method <- if (preprocessing) NULL else 1

# args[3] = number of sampling iterations
num.iterations <- as.numeric(args[3])

# args[4] = model algorithm
model.method <- args[4]

# args[5] = tune length
tune.length <- as.numeric(args[5])

# args[6] = Random Forest variable importance
rf.var.importance <- arg_is_true(args[6])
40 # args[7] = CV method as string
# NOTE(review): no assignment of cv.scheme from args[7] appears in this
# listing, yet cv.scheme is compared below -- confirm it is set (e.g.
# cv.scheme <- args[7]) before this point, or the comparison will fail.
42 ## Set cv.scheme to NULL if != CV1, CV2, CV0, or CV00
# stratified.sampling records which generic CV method was requested:
# FALSE for "random", TRUE otherwise.
43 stratified.sampling <- TRUE
44 if(cv.scheme == "random"){
45 stratified.sampling <- FALSE
47 } else if(cv.scheme == "stratified"){
# NOTE(review): the body and closing brace of this branch are not visible in
# this listing; per the comment above it presumably sets cv.scheme to NULL.
# args[8] = training data.frame: observationUnit-level phenotypes and spectra,
# supplied as a JSON string. Columns are renamed/selected into the layout the
# modelling code expects: unique.id, reference (the trait column -- assumes a
# single trait.* column), any germplasm* columns, then spectral columns X<n>.
train.input <- jsonlite::fromJSON(txt = args[8], flatten = TRUE) %>%
  rename("unique.id" = observationUnitId) %>%
  rename_at(vars(starts_with("trait.")), ~"reference") %>%
  rename_at(
    vars(starts_with("nirs_spectra")),
    ~str_replace(., "nirs_spectra.", "")
  ) %>%
  dplyr::select(
    "unique.id",
    reference,
    starts_with("germplasm"),
    num_range(prefix = "X", range = 1:100000)
  )
# Reference values may arrive as strings; the models need numbers.
train.input[["reference"]] <- as.numeric(train.input[["reference"]])

# Quick visual check of the parsed input (first 5 rows x 5 columns).
train.input[1:5, 1:5] %>% print()
61 # args[9] = test data.frame: observationUnit level data with phenotypes and spectra in JSON format
# The literal string "NULL" means no separate test set was supplied.
# NOTE(review): the closing brace / else branch is not visible in this listing;
# later code checks is.null(test.input), so test.input must be set to NULL when
# no test set is given -- confirm.
62 if(args[9] != "NULL"){
63 print("TEST DATA PROVIDED")
# Parsed exactly like the training set: unique.id, reference, germplasm*
# columns, then spectral X<n> columns.
65 test.input <- jsonlite::fromJSON(txt = args[9], flatten = T) %>%
66 rename("unique.id" = observationUnitId) %>%
67 rename_at(vars(starts_with("trait.")), ~paste0("reference")) %>%
68 rename_at(vars(starts_with("nirs_spectra")), ~str_replace(., "nirs_spectra.", "")) %>%
69 dplyr::select("unique.id", reference, starts_with("germplasm"), num_range(prefix = "X", range = 1:100000))
70 test.input$reference <- as.numeric(test.input$reference)
76 # args[10] = save model boolean
# When args[10] is "TRUE": fit with save_model(), keep the best model's
# statistics in results.df and write the best model to args[11] via saveRDS().
# NOTE(review): several closing braces, `else` lines and save_model() arguments
# are missing from this listing; review against the full file.
77 if(args[10]=="TRUE"){ # SAVE MODEL (generic or specialized cv.scheme)
# cv.scheme is NULL for the generic random/stratified CV methods: drop the
# germplasmName column and hand save_model() a single data frame.
78 if(is.null(cv.scheme)){
79 train.ready <- train.input %>% dplyr::select(-germplasmName)
# NOTE(review): this branch runs only when test.input is NULL, so the print
# below always shows NULL; the preview probably belongs in the other branch.
80 if(is.null(test.input)){
81 test.ready <- test.input
82 print(test.ready[1:5,1:5])
84 test.ready <- test.input %>% dplyr::select(-germplasmName)
87 print(train.ready[1:5,1:5])
88 print(train.ready$reference)
89 # Test model using non-specialized cv scheme
90 sm.output <- save_model(df = train.ready,
92 pretreatment = preprocessing.method,
93 model.save.folder = NULL,
94 model.name = "PredictionModel",
95 best.model.metric = "RMSE",
96 tune.length = tune.length,
97 model.method = model.method,
98 num.iterations = num.iterations,
99 stratified.sampling = stratified.sampling,
107 # Test model using specialized cv scheme AND SAVE
# Specialized CV scheme: save_model() gets df = NULL and instead receives the
# trials (trial1 = train.input, ...) together with cv.scheme.
108 sm.output <- save_model(df = NULL,
110 pretreatment = preprocessing.method,
111 model.save.folder = NULL,
112 model.name = "PredictionModel",
113 best.model.metric = "RMSE",
114 tune.length = tune.length,
115 model.method = model.method,
116 num.iterations = num.iterations,
117 stratified.sampling = stratified.sampling,
118 cv.scheme = cv.scheme,
119 trial1 = train.input,
# The best model's performance statistics become the results table; the best
# model itself is serialized to args[11].
124 results.df <- sm.output$best.model.stats
125 saveRDS(sm.output$best.model, file = args[11]) # args[11] = model save location with .Rds in filename
127 } else{ # DON'T SAVE MODEL
# Otherwise only estimate model performance with TestModelPerformance();
# no model is written to disk.
# NOTE(review): some closing braces and `else` lines are missing from this
# listing; review against the full file.
128 if(is.null(cv.scheme)){
129 train.ready <- train.input %>% dplyr::select(-germplasmName)
130 if(is.null(test.input)){
131 test.ready <- test.input
133 test.ready <- test.input %>% dplyr::select(-germplasmName)
# Wavelengths are parsed from the spectral column names (X<n>) with
# parse_number().
136 wls <- colnames(train.ready)[-c(1:2)] %>% parse_number() # take off sample name and reference columns
137 # Test model using non-specialized cv scheme
138 results.df <- TestModelPerformance(train.data = train.ready, num.iterations = num.iterations,
139 test.data = test.ready, preprocessing = preprocessing,
140 wavelengths = wls, tune.length = tune.length,
141 model.method = model.method, output.summary = TRUE,
142 rf.variable.importance = rf.var.importance,
143 stratified.sampling = stratified.sampling, cv.scheme = NULL,
144 trial1 = NULL, trial2 = NULL, trial3 = NULL)
146 # Test model using specialized cv scheme
# NOTE(review): in the visible code, train.ready is only assigned when
# cv.scheme is NULL, so on this path the next line fails with
# "object 'train.ready' not found". train.input (unique.id, reference,
# germplasm columns, spectra) matches the "take off 3 columns" logic and is
# probably what was meant, assuming germplasmName is its only germplasm*
# column -- confirm.
147 wls <- colnames(train.ready)[-c(1:3)] %>% parse_number() # take off sample name, reference, and genotype columns
# NOTE(review): stratified.sampling is hard-coded to FALSE here, whereas the
# save_model() calls pass the stratified.sampling flag -- confirm this is intended.
148 results.df <- TestModelPerformance(train.data = NULL, num.iterations = num.iterations,
149 test.data = NULL, preprocessing = preprocessing,
150 wavelengths = wls, tune.length = tune.length,
151 model.method = model.method, output.summary = TRUE,
152 rf.variable.importance = rf.var.importance,
153 stratified.sampling = FALSE, cv.scheme = cv.scheme,
154 trial1 = train.input, trial2 = test.input,
# When RF variable importance was requested, results.df is assumed to be a
# list: [[1]] the results table and [[2]] the variable importance table.
# Split them so that each is written to its own file.
157 if(rf.var.importance){
158 var.imp <- results.df[[2]]
159 results.df <- results.df[[1]]
160 # args[12] = variable importance results output file name
161 write.csv(var.imp, file = args[12], row.names = F)
162 # args[13] = variable importance figure output file name
# NOTE(review): results.plot is never created in this listing, so this ggsave()
# call fails unless it is defined elsewhere; a plot built from var.imp was
# probably intended. The closing brace of this if is also not visible.
163 ggsave(filename = args[13], plot = results.plot, device = "png")
# args[14] = output path for the results table (CSV, no row names)
write.csv(results.df, file = args[14], row.names = FALSE)

# args[15] = output path for the results figure (export currently disabled)
# ggsave(filename = args[15], plot = results.plot, device = "png")