forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathjit_log.cpp
138 lines (117 loc) · 3.83 KB
/
jit_log.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#include <cstdlib>
#include <iomanip>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include <ATen/core/function.h>
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/serialization/python_print.h>
namespace torch {
namespace jit {
std::string TORCH_API get_jit_logging_levels() {
return JitLoggingConfig::getInstance().getLoggingLevels();
}
void TORCH_API set_jit_logging_levels(std::string level) {
JitLoggingConfig::getInstance().setLoggingLevels(level);
}
// gets a string representation of a node header
// (e.g. outputs, a node kind and outputs)
std::string getHeader(const Node* node) {
std::stringstream ss;
node->print(ss, 0, {}, false, false, false, false);
return ss.str();
}
void JitLoggingConfig::parse() {
std::stringstream in_ss;
in_ss << "function:" << this->logging_levels;
std::unordered_map<std::string, size_t> new_files_to_levels;
std::string line;
while (std::getline(in_ss, line, ':')) {
if (line.size() == 0) {
continue;
}
auto index_at = line.find_last_of('>');
auto begin_index = index_at == std::string::npos ? 0 : index_at + 1;
size_t logging_level = index_at == std::string::npos ? 0 : index_at + 1;
auto end_index = line.find_last_of('.') == std::string::npos
? line.size()
: line.find_last_of('.');
auto filename = line.substr(begin_index, end_index - begin_index);
new_files_to_levels.insert({filename, logging_level});
}
this->files_to_levels = new_files_to_levels;
}
bool is_enabled(const char* cfname, JitLoggingLevels level) {
const std::unordered_map<std::string, size_t> files_to_levels =
JitLoggingConfig::getInstance().getFilesToLevels();
std::string fname{cfname};
fname = c10::detail::StripBasename(fname);
auto end_index = fname.find_last_of('.') == std::string::npos
? fname.size()
: fname.find_last_of('.');
auto fname_no_ext = fname.substr(0, end_index);
auto it = files_to_levels.find(fname_no_ext);
if (it == files_to_levels.end()) {
return false;
}
return level <= static_cast<JitLoggingLevels>(it->second);
}
// Unfortunately, in `GraphExecutor` where `log_function` is invoked
// we won't have access to an original function, so we have to construct
// a dummy function to give to PythonPrint
std::string log_function(const std::shared_ptr<torch::jit::Graph>& graph) {
torch::jit::GraphFunction func("source_dump", graph, nullptr);
std::vector<at::IValue> constants;
PrintDepsTable deps;
PythonPrint pp(constants, deps);
pp.printFunction(func);
return pp.str();
}
std::string jit_log_prefix(
const std::string& prefix,
const std::string& in_str) {
std::stringstream in_ss(in_str);
std::stringstream out_ss;
std::string line;
while (std::getline(in_ss, line)) {
out_ss << prefix << line << std::endl;
}
return out_ss.str();
}
std::string jit_log_prefix(
JitLoggingLevels level,
const char* fn,
int l,
const std::string& in_str) {
std::stringstream prefix_ss;
prefix_ss << "[";
prefix_ss << level << " ";
prefix_ss << c10::detail::StripBasename(std::string(fn)) << ":";
prefix_ss << std::setfill('0') << std::setw(3) << l;
prefix_ss << "] ";
return jit_log_prefix(prefix_ss.str(), in_str);
}
std::ostream& operator<<(std::ostream& out, JitLoggingLevels level) {
switch (level) {
case JitLoggingLevels::GRAPH_DUMP:
out << "DUMP";
break;
case JitLoggingLevels::GRAPH_UPDATE:
out << "UPDATE";
break;
case JitLoggingLevels::GRAPH_DEBUG:
out << "DEBUG";
break;
default:
TORCH_INTERNAL_ASSERT(false, "Invalid level");
}
return out;
}
} // namespace jit
} // namespace torch