1 #define _GNU_SOURCE /* struct ucred */
10 #include <sys/socket.h>
19 #define begins_with(s_, a_) (!strncmp(s_, a_, strlen(s_)))
23 logperror(const char *s
)
25 syslog(LOG_ERR
, "%s: %s", s
, strerror(errno
));
29 memory_limits(size_t *minuser
, size_t *mincomp
, size_t *maxcomp
, size_t *total
)
31 FILE *f
= fopen("/proc/meminfo", "r");
33 while (fgets(line
, sizeof(line
), f
)) {
34 if (begins_with("MemTotal:", line
)) {
36 sscanf(line
, "MemTotal:%zu", total
);
43 *minuser
= *total
* split_ratio
;
44 if (*minuser
< static_minfree
)
45 *minuser
= static_minfree
;
46 if (*minuser
> static_maxfree
)
47 *minuser
= static_maxfree
;
49 *mincomp
= static_minfree
;
50 ssize_t smaxcomp
= *total
- *minuser
;
51 *maxcomp
= smaxcomp
> 0 ? smaxcomp
: 0;
52 /* maxcomp < mincomp may happen; they are used in different
57 get_default_mem_limit(void)
59 size_t minuser
, mincomp
, maxcomp
, total
;
60 memory_limits(&minuser
, &mincomp
, &maxcomp
, &total
);
68 if (cgroup_setup(chier
, "memory") < 0)
70 int ret
= cgroup_create(chier
, cgroup
);
74 /* CGroup newly created, set limit. */
75 if (cgroup_set_mem_limit(chier
, cgroup
, get_default_mem_limit()) < 0)
82 mprintf(int fd
, char *fmt
, ...)
88 vsnprintf(buf
, sizeof(buf
), fmt
, v
);
93 .iov_len
= strlen(buf
),
99 ssize_t sent
= sendmsg(fd
, &msg
, 0);
101 logperror("sendmsg");
104 if ((size_t) sent
< iov
.iov_len
) {
105 syslog(LOG_INFO
, "incomplete send %zd < %zu, FIXME", sent
, iov
.iov_len
);
112 main(int argc
, char *argv
[])
114 /* Do this while everyone can still see the error. */
128 openlog("compctl", LOG_PID
, LOG_DAEMON
);
129 cgroup_perror
= logperror
;
133 int s
= socket(AF_UNIX
, SOCK_STREAM
, 0);
134 /* TODO: Protect against double execution? */
136 struct sockaddr_un sun
= { .sun_family
= AF_UNIX
, .sun_path
= SOCKFILE
};
137 if (bind(s
, (struct sockaddr
*) &sun
, sizeof(sun
.sun_family
) + strlen(sun
.sun_path
) + 1) < 0) {
141 chmod(SOCKFILE
, 0777);
145 while ((fd
= accept(s
, NULL
, NULL
)) >= 0) {
146 /* We handle only a single client at a time. This means
147 * that it is rather easy to write a script that will DOS
148 * the daemon, this is just an attack vector we ignore. */
149 /* TODO: alarm() to wake from stuck clients. */
151 /* Decode the message with command and credentials. */
153 int on
= 1; setsockopt(fd
, SOL_SOCKET
, SO_PASSCRED
, &on
, sizeof(on
));
156 char cbuf
[CMSG_SPACE(sizeof(*cred
))];
160 .iov_len
= sizeof(line
),
162 struct msghdr msg
= {
166 .msg_controllen
= sizeof(cbuf
),
170 int replylen
= recvmsg(fd
, &msg
, 0);
180 struct cmsghdr
*cmsg
;
181 cmsg
= CMSG_FIRSTHDR(&msg
);
182 if (cmsg
== NULL
|| cmsg
->cmsg_len
!= CMSG_LEN(sizeof(*cred
))) {
183 syslog(LOG_INFO
, "want %zu", CMSG_LEN(sizeof(*cred
)));
187 if (cmsg
->cmsg_level
!= SOL_SOCKET
|| cmsg
->cmsg_type
!= SCM_CREDENTIALS
) {
188 errmsg
= "cmsg designation";
191 cred
= (struct ucred
*) CMSG_DATA(cmsg
);
195 syslog(LOG_WARNING
, "protocol error (%s)", line
);
200 /* Analyze command */
201 if (!strcmp("blessme", line
)) {
202 syslog(LOG_INFO
, "new computation process %d", cred
->pid
);
203 if (cgroup_add_task(chier
, cgroup
, cred
->pid
) < 0)
204 mprintf(fd
, "0 error: %s", strerror(errno
));
206 mprintf(fd
, "1 blessed");
208 } else if (begins_with("kill ", line
)) {
209 pid_t pid
= atoi(line
+ strlen("kill "));
213 syslog(LOG_WARNING
, "kill: invalid pid (%d)", pid
);
214 mprintf(fd
, "0 invalid pid");
218 if (!cgroup_is_task_in_cgroup(chier
, cgroup
, pid
)) {
219 mprintf(fd
, "0 task not marked as computation");
224 syslog(LOG_INFO
, "killing process %d (request by pid %d uid %d)", pid
, cred
->pid
, cred
->uid
);
226 /* TODO: Grace period and then kill with SIGKILL. */
227 mprintf(fd
, "1 task killed");
229 } else if (!strcmp("killall", line
)) {
231 int tasks_n
= cgroup_task_list(chier
, cgroup
, &tasks
);
233 mprintf(fd
, "0 error: %s\r\n", strerror(errno
));
237 for (int i
= 0; i
< tasks_n
; i
++) {
238 syslog(LOG_INFO
, "killing process %d (mass request by pid %d uid %d)", tasks
[i
], cred
->pid
, cred
->uid
);
239 kill(tasks
[i
], SIGTERM
);
241 /* TODO: Grace period and then kill with SIGKILL. */
242 mprintf(fd
, "1 %d tasks killed", tasks_n
);
245 } else if (begins_with("limitmem ", line
)) {
246 size_t limit
= atol(line
+ strlen("limitmem "));
247 size_t minuser
, mincomp
, maxcomp
, total
;
248 memory_limits(&minuser
, &mincomp
, &maxcomp
, &total
);
250 if (limit
< mincomp
) {
251 mprintf(fd
, "-1 at least %zuM must remain available for computations.", mincomp
/ 1048576);
255 if (limit
> total
|| total
- limit
< minuser
) {
256 mprintf(fd
, "-2 at least %zuM must remain available for users; maximum limit for computations is %zuM.", minuser
/ 1048576, (total
- minuser
) / 1048576);
261 syslog(LOG_INFO
, "setting limit %zu (request by pid %d uid %d)", limit
, cred
->pid
, cred
->uid
);
262 if (cgroup_set_mem_limit(chier
, cgroup
, limit
) < 0)
263 mprintf(fd
, "0 error: %s", strerror(errno
));
265 mprintf(fd
, "1 limit set");
268 syslog(LOG_WARNING
, "invalid command (%s)", line
);