1 // SPDX-License-Identifier: GPL-2.0
#include <errno.h>
#include <pthread.h>
#include <stdlib.h>
#include <string.h>
#include <internal/rc_check.h>
#include <linux/refcount.h>
#include <linux/zalloc.h>
10 DECLARE_RC_STRUCT(comm_str
) {
15 static struct comm_strs
{
16 struct rw_semaphore lock
;
17 struct comm_str
**strs
;
22 static void comm_strs__remove_if_last(struct comm_str
*cs
);
24 static void comm_strs__init(void)
26 init_rwsem(&_comm_strs
.lock
);
27 _comm_strs
.capacity
= 16;
28 _comm_strs
.num_strs
= 0;
29 _comm_strs
.strs
= calloc(16, sizeof(*_comm_strs
.strs
));
32 static struct comm_strs
*comm_strs__get(void)
34 static pthread_once_t comm_strs_type_once
= PTHREAD_ONCE_INIT
;
36 pthread_once(&comm_strs_type_once
, comm_strs__init
);
41 static refcount_t
*comm_str__refcnt(struct comm_str
*cs
)
43 return &RC_CHK_ACCESS(cs
)->refcnt
;
46 static const char *comm_str__str(const struct comm_str
*cs
)
48 return &RC_CHK_ACCESS(cs
)->str
[0];
/*
 * Take an additional reference to @cs and return the handle the caller must
 * later pass to comm_str__put(). Presumably yields NULL when @cs is NULL,
 * per RC_CHK_GET() semantics — confirm against rc_check.h.
 */
static struct comm_str *comm_str__get(struct comm_str *cs)
{
	struct comm_str *result;

	if (RC_CHK_GET(result, cs))
		refcount_inc_not_zero(comm_str__refcnt(cs));

	return result;
}
/*
 * Drop a reference to @cs; NULL is a no-op. Dropping the last reference
 * frees the object. If exactly one reference remains afterwards, it may be
 * the table's own, so comm_strs__remove_if_last() re-checks that under the
 * table lock and evicts the entry.
 */
static void comm_str__put(struct comm_str *cs)
{
	if (!cs)
		return;

	if (refcount_dec_and_test(comm_str__refcnt(cs))) {
		RC_CHK_FREE(cs);
	} else {
		/*
		 * NOTE(review): this unlocked read can race with another thread's
		 * comm_strs__remove_if_last() freeing the object; confirm callers
		 * cannot drop the last two references concurrently.
		 */
		if (refcount_read(comm_str__refcnt(cs)) == 1)
			comm_strs__remove_if_last(cs);

		RC_CHK_PUT(cs);
	}
}
76 static struct comm_str
*comm_str__new(const char *str
)
78 struct comm_str
*result
= NULL
;
79 RC_STRUCT(comm_str
) *cs
;
81 cs
= malloc(sizeof(*cs
) + strlen(str
) + 1);
82 if (ADD_RC_CHK(result
, cs
)) {
83 refcount_set(comm_str__refcnt(result
), 1);
84 strcpy(&cs
->str
[0], str
);
/*
 * bsearch() comparator. @_key is the C string being looked up; @_member
 * points at one slot of the comm_str pointer array. Orders by strcmp().
 */
static int comm_str__search(const void *_key, const void *_member)
{
	const struct comm_str *const *slot = _member;

	return strcmp((const char *)_key, comm_str__str(*slot));
}
97 static void comm_strs__remove_if_last(struct comm_str
*cs
)
99 struct comm_strs
*comm_strs
= comm_strs__get();
101 down_write(&comm_strs
->lock
);
103 * Are there only references from the array, if so remove the array
104 * reference under the write lock so that we don't race with findnew.
106 if (refcount_read(comm_str__refcnt(cs
)) == 1) {
107 struct comm_str
**entry
;
109 entry
= bsearch(comm_str__str(cs
), comm_strs
->strs
, comm_strs
->num_strs
,
110 sizeof(struct comm_str
*), comm_str__search
);
111 comm_str__put(*entry
);
112 for (int i
= entry
- comm_strs
->strs
; i
< comm_strs
->num_strs
- 1; i
++)
113 comm_strs
->strs
[i
] = comm_strs
->strs
[i
+ 1];
114 comm_strs
->num_strs
--;
116 up_write(&comm_strs
->lock
);
119 static struct comm_str
*__comm_strs__find(struct comm_strs
*comm_strs
, const char *str
)
121 struct comm_str
**result
;
123 result
= bsearch(str
, comm_strs
->strs
, comm_strs
->num_strs
, sizeof(struct comm_str
*),
129 return comm_str__get(*result
);
132 static struct comm_str
*comm_strs__findnew(const char *str
)
134 struct comm_strs
*comm_strs
= comm_strs__get();
135 struct comm_str
*result
;
140 down_read(&comm_strs
->lock
);
141 result
= __comm_strs__find(comm_strs
, str
);
142 up_read(&comm_strs
->lock
);
146 down_write(&comm_strs
->lock
);
147 result
= __comm_strs__find(comm_strs
, str
);
149 if (comm_strs
->num_strs
== comm_strs
->capacity
) {
150 struct comm_str
**tmp
;
152 tmp
= reallocarray(comm_strs
->strs
,
153 comm_strs
->capacity
+ 16,
154 sizeof(*comm_strs
->strs
));
156 up_write(&comm_strs
->lock
);
159 comm_strs
->strs
= tmp
;
160 comm_strs
->capacity
+= 16;
162 result
= comm_str__new(str
);
164 int low
= 0, high
= comm_strs
->num_strs
- 1;
165 int insert
= comm_strs
->num_strs
; /* Default to inserting at the end. */
167 while (low
<= high
) {
168 int mid
= low
+ (high
- low
) / 2;
169 int cmp
= strcmp(comm_str__str(comm_strs
->strs
[mid
]), str
);
178 memmove(&comm_strs
->strs
[insert
+ 1], &comm_strs
->strs
[insert
],
179 (comm_strs
->num_strs
- insert
) * sizeof(struct comm_str
*));
180 comm_strs
->num_strs
++;
181 comm_strs
->strs
[insert
] = result
;
184 up_write(&comm_strs
->lock
);
185 return comm_str__get(result
);
188 struct comm
*comm__new(const char *str
, u64 timestamp
, bool exec
)
190 struct comm
*comm
= zalloc(sizeof(*comm
));
195 comm
->start
= timestamp
;
198 comm
->comm_str
= comm_strs__findnew(str
);
199 if (!comm
->comm_str
) {
207 int comm__override(struct comm
*comm
, const char *str
, u64 timestamp
, bool exec
)
209 struct comm_str
*new, *old
= comm
->comm_str
;
211 new = comm_strs__findnew(str
);
216 comm
->comm_str
= new;
217 comm
->start
= timestamp
;
224 void comm__free(struct comm
*comm
)
226 comm_str__put(comm
->comm_str
);
230 const char *comm__str(const struct comm
*comm
)
232 return comm_str__str(comm
->comm_str
);