1 //===- TrainingLoggerTest.cpp - test for TrainingLogger -------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Analysis/Utils/TrainingLogger.h"
10 #include "google/protobuf/struct.pb.h"
11 #include "tensorflow/core/example/example.pb.h"
12 #include "tensorflow/core/example/feature.pb.h"
13 #include "llvm/Analysis/TensorSpec.h"
14 #include "llvm/AsmParser/Parser.h"
15 #include "llvm/IR/Dominators.h"
16 #include "llvm/IR/Instructions.h"
17 #include "llvm/IR/LLVMContext.h"
18 #include "llvm/IR/Module.h"
19 #include "llvm/Support/Path.h"
20 #include "llvm/Support/SourceMgr.h"
21 #include "llvm/Testing/Support/SupportHelpers.h"
22 #include "gtest/gtest.h"
26 extern const char *TestMainArgv0
;
28 // NOTE! This test model is currently also used by test/Transforms/Inline/ML
30 //- relevant if updating this model.
// PROTO_CHECKER(FNAME, TYPE, INDEX, EXP): verifies values stored in the
// parsed log. It reads from a variable named `Expected`, which must be in
// scope at the expansion site (every test below declares a
// tensorflow::SequenceExample with that name).
// Every call site passes, in order:
//  - a feature name;
//  - a value-list accessor (float_list / int64_list);
//  - a step index;
//  - a C array of expected values.
// Presumably the macro selects that feature's entry for that step. The
// accessor chain between `feature_lists()` and the loop is not visible in
// this view.
// Values are compared element-wise with EXPECT_EQ for I in [0, V.size()).
// NOTE(review): nothing visible here checks that V.size() matches the length
// of EXP. An empty V would pass vacuously, and a V longer than EXP would read
// past the end of EXP.
32 #define PROTO_CHECKER(FNAME, TYPE, INDEX, EXP) \
34 const auto &V = Expected.feature_lists() \
40 for (auto I = 0; I < V.size(); ++I) \
41 EXPECT_EQ(V.at(I), EXP[I]); \
// Logs two steps of two features plus a float reward per step. It then parses
// the serialized result as a tensorflow::SequenceExample and checks every
// logged value.
//  - Feature 0: float spec "the_float" with dims {2, 3} (6 values per step).
//  - Feature 1: int64_t spec "the_int" with dims {2}, given the alternate name
//    "alternate_name". The checks below look it up under that name, not
//    "the_int".
// NOTE(review): several lines of this test are absent from this view; see
// the inline notes.
44 TEST(TrainingLoggerTest
, Logger
) {
45 std::vector
<LoggedFeatureSpec
> Features
;
// NOTE(review): the call that consumes this braced spec (presumably
// `Features.push_back(`) is not present in this view.
47 {TensorSpec::createSpec
<float>("the_float", {2, 3}), None
});
48 Features
.push_back({TensorSpec::createSpec
<int64_t>("the_int", {2}),
49 std::string("alternate_name")});
// Rewards use a float spec "reward" with dims {1}. The third Logger argument
// is `true` here. LoggerNoReward passes `false` and checks no rewards, so
// this flag presumably enables reward logging.
51 auto Rewards
= TensorSpec::createSpec
<float>("reward", {1});
52 Logger
L(Features
, Rewards
, true);
// Step 0. The feature index given to each log call appears to follow the
// order of `Features`: 0 takes floats, 1 takes int64_t values.
53 const float F00
[]{0.0, 0.1, 0.2, 0.3, 0.4, 0.5};
54 const int64_t F01
[]{2, 3};
56 L
.logFloatValue(0, F00
);
57 L
.logInt64Value(1, F01
);
58 L
.logFloatReward(3.4);
// Step 1.
59 const float F10
[]{0.0, 1.0, 2.0, 3.0, 4.0, 5.0};
60 const int64_t F11
[]{-2, -3};
61 L
.logFloatValue(0, F10
);
62 L
.logInt64Value(1, F11
);
63 L
.logFloatReward(-3.0);
// Serialize the log into the string `Result` and parse it back.
// NOTE(review): `Result` is not declared in this view. No call writing `L`
// into `OS` is visible either; those lines appear to be missing.
65 raw_string_ostream
OS(Result
);
68 tensorflow::SequenceExample Expected
;
69 ASSERT_TRUE(Expected
.ParseFromString(Result
));
// Index 0 / 1 selects the values logged at step 0 / step 1.
70 PROTO_CHECKER("the_float", float_list
, 0, F00
);
71 PROTO_CHECKER("the_float", float_list
, 1, F10
);
72 PROTO_CHECKER("alternate_name", int64_list
, 0, F01
);
73 PROTO_CHECKER("alternate_name", int64_list
, 1, F11
);
// NOTE(review): `R0` and `R1` are not defined in this view. Presumably they
// hold the rewards 3.4 and -3.0 logged above.
76 PROTO_CHECKER("reward", float_list
, 0, R0
);
77 PROTO_CHECKER("reward", float_list
, 1, R1
);
// Same layout as the Logger test, but the second feature and the reward use
// int32_t specs. Feature values are logged with logInt32Value.
// The int32 values are read back through the `int64_list` accessor, both for
// "alternate_name" and for "reward".
// NOTE(review): several lines of this test are absent from this view; see
// the inline notes.
80 TEST(TrainingLoggerTest
, LoggerInt32FeaturesAndReward
) {
81 std::vector
<LoggedFeatureSpec
> Features
;
// NOTE(review): the call that consumes this braced spec (presumably
// `Features.push_back(`) is not present in this view.
83 {TensorSpec::createSpec
<float>("the_float", {2, 3}), None
});
84 Features
.push_back({TensorSpec::createSpec
<int32_t>("the_int", {2}),
85 std::string("alternate_name")});
// Int32 reward spec with dims {1}. The Logger is created with `true`, as in
// the Logger test.
87 auto Rewards
= TensorSpec::createSpec
<int32_t>("reward", {1});
88 Logger
L(Features
, Rewards
, true);
// Step 0.
// NOTE(review): no reward-logging call is visible for either step, yet
// "reward" entries 0 and 1 are checked below. The lines that log the rewards
// appear to be missing from this view.
89 const float F00
[]{0.0, 0.1, 0.2, 0.3, 0.4, 0.5};
90 const int32_t F01
[]{2, 3};
92 L
.logFloatValue(0, F00
);
93 L
.logInt32Value(1, F01
);
// Step 1.
95 const float F10
[]{0.0, 1.0, 2.0, 3.0, 4.0, 5.0};
96 const int32_t F11
[]{-2, -3};
97 L
.logFloatValue(0, F10
);
98 L
.logInt32Value(1, F11
);
// NOTE(review): `Result` is not declared in this view, and no call writing
// `L` into `OS` is visible.
101 raw_string_ostream
OS(Result
);
104 tensorflow::SequenceExample Expected
;
105 ASSERT_TRUE(Expected
.ParseFromString(Result
));
106 PROTO_CHECKER("the_float", float_list
, 0, F00
);
107 PROTO_CHECKER("the_float", float_list
, 1, F10
);
// The int32_t feature values are compared against the int64_list accessor.
108 PROTO_CHECKER("alternate_name", int64_list
, 0, F01
);
109 PROTO_CHECKER("alternate_name", int64_list
, 1, F11
);
// NOTE(review): `R0` and `R1` are not defined in this view.
112 PROTO_CHECKER("reward", int64_list
, 0, R0
);
113 PROTO_CHECKER("reward", int64_list
, 1, R1
);
// Same features as the Logger test, but the Logger is constructed with
// `false` and no rewards are logged. Only the feature values are checked
// after parsing.
// NOTE(review): a reward spec is still passed to the Logger. The visible
// checks never assert that a "reward" feature list is absent, so this test
// does not verify the "no reward" behavior its name suggests.
116 TEST(TrainingLoggerTest
, LoggerNoReward
) {
117 std::vector
<LoggedFeatureSpec
> Features
;
// NOTE(review): the call that consumes this braced spec (presumably
// `Features.push_back(`) is not present in this view.
119 {TensorSpec::createSpec
<float>("the_float", {2, 3}), None
});
120 Features
.push_back({TensorSpec::createSpec
<int64_t>("the_int", {2}),
121 std::string("alternate_name")});
123 auto Rewards
= TensorSpec::createSpec
<float>("reward", {1});
124 Logger
L(Features
, Rewards
, false);
// Step 0 (no reward call).
125 const float F00
[]{0.0, 0.1, 0.2, 0.3, 0.4, 0.5};
126 const int64_t F01
[]{2, 3};
128 L
.logFloatValue(0, F00
);
129 L
.logInt64Value(1, F01
);
// Step 1 (no reward call).
130 const float F10
[]{0.0, 1.0, 2.0, 3.0, 4.0, 5.0};
131 const int64_t F11
[]{-2, -3};
132 L
.logFloatValue(0, F10
);
133 L
.logInt64Value(1, F11
);
// NOTE(review): `Result` is not declared in this view, and no call writing
// `L` into `OS` is visible.
136 raw_string_ostream
OS(Result
);
138 tensorflow::SequenceExample Expected
;
139 ASSERT_TRUE(Expected
.ParseFromString(Result
));
140 PROTO_CHECKER("the_float", float_list
, 0, F00
);
141 PROTO_CHECKER("the_float", float_list
, 1, F10
);
142 PROTO_CHECKER("alternate_name", int64_list
, 0, F01
);
143 PROTO_CHECKER("alternate_name", int64_list
, 1, F11
);
// Logs three steps. At each step the float and the int64 feature are both
// set to the step number. Only a single final reward of 3.14 is logged,
// after the loop.
// The checks expect reward entries 0 and 1 to be 0.0 and entry 2 to be 3.14,
// i.e. the final reward lands on the last step and earlier steps read as 0.
// Feature values are not checked in this test.
146 TEST(TrainingLoggerTest
, LoggerFinalReward
) {
147 std::vector
<LoggedFeatureSpec
> Features
;
148 Features
.push_back({TensorSpec::createSpec
<float>("the_float", {1}), None
});
149 Features
.push_back({TensorSpec::createSpec
<int64_t>("the_int", {1}), None
});
151 auto Rewards
= TensorSpec::createSpec
<float>("reward", {1});
152 Logger
L(Features
, Rewards
, true);
// `I` is int64_t so that `&I` can be passed directly to logInt64Value. The
// float feature needs its own converted copy `F`.
// NOTE(review): the closing `}` of this loop is not visible in this view.
// Presumably it ends before the final-reward call.
153 for (int64_t I
= 0; I
< 3; ++I
) {
154 float F
= static_cast<float>(I
);
155 L
.logFloatValue(0, &F
);
156 L
.logInt64Value(1, &I
);
158 L
.logFloatFinalReward(3.14);
// NOTE(review): `Result` is not declared in this view, and no call writing
// `L` into `OS` is visible.
160 raw_string_ostream
OS(Result
);
// Expected per-step rewards: 0.0 for the non-final steps, 3.14 for the last.
162 const float Zero
[]{0.0};
163 const float R
[]{3.14};
164 tensorflow::SequenceExample Expected
;
165 ASSERT_TRUE(Expected
.ParseFromString(Result
));
166 PROTO_CHECKER("reward", float_list
, 0, Zero
);
167 PROTO_CHECKER("reward", float_list
, 1, Zero
);
168 PROTO_CHECKER("reward", float_list
, 2, R
);
171 TEST(TrainingLoggerTest
, LoggerGroup
) {
172 std::vector
<LoggedFeatureSpec
> Features
;
173 Features
.push_back({TensorSpec::createSpec
<float>("the_float", {1}), None
});
174 Features
.push_back({TensorSpec::createSpec
<int64_t>("the_int", {1}), None
});
176 auto Rewards
= TensorSpec::createSpec
<float>("reward", {1});
177 StringMap
<std::unique_ptr
<Logger
>> Loggers
;
178 std::vector
<std::string
> Names
{"a", "b"};
180 for (auto Name
: Names
) {
181 auto L
= std::make_unique
<Logger
>(Features
, Rewards
, true);
182 for (int64_t I
= 0; I
< 3; ++I
) {
183 float F
= static_cast<float>(I
) + Bump
;
184 L
->logFloatValue(0, &F
);
185 L
->logInt64Value(1, &I
);
187 L
->logFloatFinalReward(3.14 + Bump
);
188 Loggers
.insert(std::make_pair(Name
, std::move(L
)));
191 raw_string_ostream
OS(Result
);
192 Logger::flushLogs(OS
, Loggers
);
193 google::protobuf::Struct Expected
;
194 ASSERT_TRUE(Expected
.ParseFromString(Result
));
195 EXPECT_EQ(Expected
.fields_size(), 2);
196 EXPECT_TRUE(Expected
.fields().contains("a"));
197 EXPECT_TRUE(Expected
.fields().contains("b"));