Slow work on colourspace decoding
[jirac.git] / org / diracvideo / Jirac / Motion.java
blobb861b9eb639b2423be3638e01db7aa924a6a1934
1 package org.diracvideo.Jirac;
2 import java.awt.Dimension;
3 import java.awt.Point;
5 /** Motion
7 * An ill-named class representing an object
8 * which does motion compensation prediction
9 * on a picture. **/
10 class Motion {
11 Parameters par;
12 Vector vecs[];
13 Block refs[][];
14 Arithmetic ar[];
15 int xbsep, ybsep, xblen, yblen, xoffset, yoffset;
16 int chroma_h_shift, chroma_v_shift;
17 short weight_x[], weight_y[], obmc[];
18 Block block, tmp_ref[];
20 static int ARITH_SUPERBLOCK = 0;
21 static int ARITH_PRED_MODE = 1;
22 static int ARITH_REF1_X = 2;
23 static int ARITH_REF1_Y = 3;
24 static int ARITH_REF2_X = 4;
25 static int ARITH_REF2_Y = 5;
26 static int ARITH_DC_0 = 6;
27 static int ARITH_DC_1 = 7;
28 static int ARITH_DC_2 = 8;
30 public Motion(Parameters p, Buffer bufs[], Block frames[][]) {
31 par = p;
32 refs = frames;
33 vecs = new Vector[par.x_num_blocks * par.y_num_blocks];
34 tmp_ref = new Block[refs.length];
35 ar = new Arithmetic[9];
36 for(int i = 0; i < 9; i++)
37 if(bufs[i] != null) ar[i] = new Arithmetic(bufs[i]);
41 public void decode() {
42 for(int y = 0; y < par.y_num_blocks; y += 4)
43 for(int x = 0; x < par.x_num_blocks; x += 4)
44 decodeMacroBlock(x,y);
47 private void decodeMacroBlock(int x, int y) {
48 int split = splitPrediction(x,y);
49 Vector mv = getVector(x,y);
50 mv.split = (split + ar[ARITH_SUPERBLOCK].decodeUint(Context.SB_F1, Context.SB_DATA))%3;
51 switch(mv.split) {
52 case 0:
53 decodePredictionUnit(mv, x, y);
54 for(int i = 0; i < 4; i++)
55 for(int j = 0; j < 4; j++)
56 setVector(mv, x + j, y + i);
57 break;
58 case 1:
59 for(int i = 0; i < 4; i += 2)
60 for(int j = 0; j < 4; j += 2) {
61 mv = getVector(x + j, y + i);
62 mv.split = 1;
63 decodePredictionUnit(mv, x + j, y + i);
64 setVector(mv, x + j + 1, y + i);
65 setVector(mv, x + j, y + i + 1);
66 setVector(mv, x + j + 1, y + i + 1);
68 break;
69 case 2:
70 for(int i = 0; i < 4; i++)
71 for(int j = 0; j < 4; j++) {
72 mv = getVector(x + j, y + i);
73 mv.split = 2;
74 decodePredictionUnit(mv, x + j, y + i);
76 break;
77 default:
78 throw new Error("Unsupported splitting mode");
/**
 * Decodes the prediction data of one prediction unit into mv.
 *
 * The prediction mode is the value predicted from neighbouring blocks,
 * XOR-ed with decoded bits. Mode 0 reads three DC values; any other mode
 * reads the global-motion flag and, for non-global blocks, one motion
 * vector per reference selected by the mode bits. Every decoded value is
 * a residual combined with a prediction from neighbouring blocks.
 *
 * @param mv vector record to fill in
 * @param x  block column of the unit
 * @param y  block row of the unit
 */
82 private void decodePredictionUnit(Vector mv, int x, int y) {
// Bit 0 of the mode is always coded; bit 1 only when there is a second reference.
83 mv.pred_mode = modePrediction(x,y);
84 mv.pred_mode ^= ar[ARITH_PRED_MODE].decodeBit(Context.BLOCK_MODE_REF1);
85 if(par.num_refs > 1) {
86 mv.pred_mode ^= (ar[ARITH_PRED_MODE].decodeBit(Context.BLOCK_MODE_REF2) << 1);
88 if(mv.pred_mode == 0) {
// Mode 0: one DC value per component, each coded as prediction + signed residual.
89 int pred[] = new int[3];
90 dcPrediction(x,y,pred);
91 mv.dc[0] = pred[0] +
92 ar[ARITH_DC_0].decodeSint(Context.LUMA_DC_CONT_BIN1,
93 Context.LUMA_DC_VALUE,
94 Context.LUMA_DC_SIGN);
95 mv.dc[1] = pred[1] +
96 ar[ARITH_DC_1].decodeSint(Context.CHROMA1_DC_CONT_BIN1,
97 Context.CHROMA1_DC_VALUE,
98 Context.CHROMA1_DC_SIGN);
99 mv.dc[2] = pred[2] +
100 ar[ARITH_DC_2].decodeSint(Context.CHROMA2_DC_CONT_BIN1,
101 Context.CHROMA2_DC_VALUE,
102 Context.CHROMA2_DC_SIGN);
103 } else {
// pred_x and pred_y are declared but not used in the visible code.
104 int pred_x, pred_y;
// The global-motion flag is only coded when the picture has global motion.
105 if(par.have_global_motion) {
106 int pred = globalPrediction(x,y);
107 pred ^= ar[ARITH_SUPERBLOCK].decodeBit(Context.GLOBAL_BLOCK);
108 mv.using_global = (pred == 0 ? false : true);
109 } else {
110 mv.using_global = false;
112 if(!mv.using_global) {
// Mode bit 0 selects reference 1: vectorPrediction writes the predicted
// vector into mv, then the decoded residual is added on top.
113 if((mv.pred_mode & 1) != 0) {
114 vectorPrediction(mv,x,y,1);
115 mv.dx[0] +=
116 ar[ARITH_REF1_X].decodeSint(Context.MV_REF1_H_CONT_BIN1,
117 Context.MV_REF1_H_VALUE,
118 Context.MV_REF1_H_SIGN);
119 mv.dy[0] +=
120 ar[ARITH_REF1_Y].decodeSint(Context.MV_REF1_V_CONT_BIN1,
121 Context.MV_REF1_V_VALUE,
122 Context.MV_REF1_V_SIGN);
// Mode bit 1 selects reference 2, handled the same way.
124 if((mv.pred_mode & 2) != 0) {
125 vectorPrediction(mv, x, y, 2);
126 mv.dx[1] += ar[ARITH_REF2_X].decodeSint(Context.MV_REF2_H_CONT_BIN1,
127 Context.MV_REF2_H_VALUE,
128 Context.MV_REF2_H_SIGN);
129 mv.dy[1] += ar[ARITH_REF2_Y].decodeSint(Context.MV_REF2_V_CONT_BIN1,
130 Context.MV_REF2_V_VALUE,
131 Context.MV_REF2_V_SIGN);
// NOTE(review): the statements below clear the global flag and zero both
// vectors. The lines that delimit them are not visible in this listing, so
// it is unclear whether they run unconditionally (discarding the vectors
// decoded above), belong to a branch, or are commented out. Confirm against
// the original file.
136 mv.using_global = false;
137 mv.dx[0] = 0;
138 mv.dy[0] = 0;
139 mv.dx[1] = 0;
140 mv.dy[1] = 0;
143 public void render(Block out[], VideoFormat f) {
144 for(int k = 0; k < out.length; k++) {
145 initializeRender(k,f);
146 block = new Block(new Dimension(xblen, yblen));
147 for(int i = 0; i < par.num_refs; i++)
148 tmp_ref[i] = refs[i][k];
149 for(int j = 0; j < par.y_num_blocks; j++)
150 for(int i = 0; i < par.x_num_blocks; i++) {
151 predictBlock(out[k], i, j, k);
152 accumulateBlock(out[k], i*xbsep - xoffset,
153 j*ybsep - yoffset);
155 out[k].shiftOut(6,0);
156 out[k].clip(7);
160 private void initializeRender(int k, VideoFormat f) {
161 chroma_h_shift = f.chromaHShift();
162 chroma_v_shift = f.chromaVShift();
163 yblen = par.yblen_luma;
164 xblen = par.xblen_luma;
165 ybsep = par.ybsep_luma;
166 xbsep = par.xbsep_luma;
167 if(k != 0) {
168 yblen >>= chroma_v_shift;
169 ybsep >>= chroma_v_shift;
170 xbsep >>= chroma_h_shift;
171 xblen >>= chroma_h_shift;
173 yoffset = (yblen - ybsep) >> 1;
174 xoffset = (xblen - xbsep) >> 1;
175 /* initialize obmc weight */
176 weight_y = new short[yblen];
177 weight_x = new short[xblen];
178 obmc = new short[xblen*yblen];
179 for(int i = 0; i < xblen; i++) {
180 short wx;
181 if(xoffset == 0) {
182 wx = 8;
183 } else if( i < 2*xoffset) {
184 wx = Util.getRamp(i, xoffset);
185 } else if(xblen - 1 - i < 2*xoffset) {
186 wx = Util.getRamp(xblen - 1 - i, xoffset);
187 } else {
188 wx = 8;
190 weight_x[i] = wx;
192 for(int j = 0; j < yblen; j++) {
193 short wy;
194 if(yoffset == 0) {
195 wy = 8;
196 } else if(j < 2*yoffset) {
197 wy = Util.getRamp(j, yoffset);
198 } else if(yblen - 1 - j < 2*yoffset) {
199 wy = Util.getRamp(yblen - 1 - j, yoffset);
200 } else {
201 wy = 8;
203 weight_y[j] = wy;
207 private void predictBlock(Block out, int i, int j, int k) {
208 int xstart = (i*xbsep) - xoffset,
209 ystart = (j*ybsep) - yoffset;
210 Vector mv = getVector(i,j);
211 if(mv.pred_mode == 0) {
212 for(int q = 0; j < yblen; j++)
213 for(int p = 0; i < xblen; i++)
214 block.set(p, q, (mv.dc[k]));
216 if(k != 0 && !mv.using_global)
217 mv = mv.scale(chroma_h_shift, chroma_v_shift);
218 for(int q = 0; q < yblen; q++) {
219 int y = ystart + q;
220 if(y < 0 || y > out.s.height - 1) continue;
221 for(int p = 0; p < xblen; p++) {
222 int x = xstart + p;
223 if(x < 0 || x > out.s.width - 1) continue;
224 block.set(p,q, predictPixel(mv, x, y, k));
229 private short predictPixel(Vector mv, int x, int y, int k) {
230 if(mv.using_global) {
231 for(int i = 0; i < par.num_refs; i++) {
232 par.global[i].getVector(mv, x, y, i);
234 if(k != 0)
235 mv = mv.scale(chroma_h_shift, chroma_v_shift);
237 short weight = (short)(par.picture_weight_1 + par.picture_weight_2);
238 short val = 0;
239 int px, py;
240 switch(mv.pred_mode) {
241 case 1:
242 px = (x << par.mv_precision) + mv.dx[0];
243 py = (y << par.mv_precision) + mv.dy[0];
244 val = (short)(weight*predictSubPixel(0, px, py));
245 break;
246 case 2:
247 px = (x << par.mv_precision) + mv.dx[1];
248 py = (y << par.mv_precision) + mv.dy[1];
249 val = (short)(weight*predictSubPixel(1, px, py));
250 break;
251 case 3:
252 px = (x << par.mv_precision) + mv.dx[0];
253 py = (y << par.mv_precision) + mv.dy[0];
254 val = (short)(par.picture_weight_1*predictSubPixel(0, px, py));
255 px = (x << par.mv_precision) + mv.dx[1];
256 py = (x << par.mv_precision) + mv.dy[1];
257 val += (short)(par.picture_weight_2*predictSubPixel(1, px, py));
258 default:
259 break;
261 return (short)Util.roundShift(val, par.picture_weight_bits);
264 private short predictSubPixel(int ref, int px, int py) {
265 if(par.mv_precision < 2) {
266 return tmp_ref[ref].real(px, py);
268 int prec = par.mv_precision;
269 int add = 1 << (prec - 1);
270 int hx = px >> (prec-1);
271 int hy = py >> (prec-1);
272 int rx = px - (hx << (prec-1));
273 int ry = py - (hy << (prec-1));
274 int w00,w01, w10, w11;
275 w00 = (add - rx)*(add - ry);
276 w01 = (add - rx)*ry;
277 w10 = rx*(add - ry);
278 w11 = rx*ry;
279 int val = w00*tmp_ref[ref].real(hx, hy) +
280 w01*tmp_ref[ref].real(hx + 1, hy) +
281 w10*tmp_ref[ref].real(hx, hy + 1) +
282 w11*tmp_ref[ref].real(hx + 1, hy + 1);
283 return (short)((val + (1 << (2*prec-3))) >> (2*prec - 2));
287 private void accumulateBlock(Block out, int x, int y) {
288 if(!edge(x,y)) {
289 for(int q = 0; q < yblen; q++) {
290 if(q + y < 0 || q + y >= out.s.height) continue;
291 int outLine = out.index(x, y + q);
292 int inLine = block.line(q);
293 for(int p = 0; p < xblen; p++) {
294 if(p + x < 0 || p + x >= out.s.width) continue;
295 out.d[outLine + p] +=
296 (short)(weight_x[p]*weight_y[q]*block.d[inLine+p]);
299 } else {
300 int w_x, w_y;
301 for(int q = 0; q < yblen; q++) {
302 if(q + y < 0 || q + y >= out.s.height) continue;
303 if((y < 0 && q < 2*yoffset) ||
304 (y >= par.y_num_blocks*ybsep - yoffset &&
305 yblen - 1 - q < 2*yoffset))
306 w_y = 8;
307 else
308 w_y = weight_y[q];
309 int outLine = out.index(x, y + q);
310 int inLine = block.line(q);
311 for(int p = 0; p < xblen; p++) {
312 if(p + x < 0 || p + x >= out.s.width) continue;
313 if((x < 0 && p < 2*xoffset) ||
314 (x >= par.x_num_blocks*xbsep - xoffset &&
315 xblen - 1 - p < 2*xoffset))
316 w_x = 8;
317 else
318 w_x = weight_x[p];
319 out.d[outLine + p] +=
320 (short)(w_x*w_y*block.d[inLine+p]);
328 private boolean edge(int x, int y) {
329 return (x < 0 || x >= par.x_num_blocks*xbsep - xoffset)
330 || (y < 0 || y >= par.y_num_blocks*ybsep - yoffset);
334 private int splitPrediction(int x, int y) {
335 if(y == 0) {
336 if(x == 0) {
337 return 0;
338 } else {
339 return vecs[x-4].split;
341 } else {
342 if(x == 0) {
343 return getVector(0, y - 4).split;
344 } else {
345 int sum = 0;
346 sum += getVector(x, y - 4).split;
347 sum += getVector(x - 4, y).split;
348 sum += getVector(x - 4, y - 4).split;
349 return (sum+1)/3;
354 private int modePrediction(int x, int y) {
355 if(y == 0) {
356 if(x == 0) {
357 return 0;
358 } else {
359 return vecs[x - 1].pred_mode;
361 } else {
362 if(x == 0) {
363 return getVector(0, y - 1).pred_mode;
364 } else {
365 int a,b,c;
366 a = getVector(x - 1, y).pred_mode;
367 b = getVector(x, y - 1).pred_mode;
368 c = getVector(x - 1, y - 1).pred_mode;
369 return (a&b)|(b&c)|(c&a);
374 private int globalPrediction(int x, int y) {
375 if(x == 0 && y == 0) {
376 return 0;
378 if(y == 0) {
379 return vecs[x-1].using_global ? 1 : 0;
381 if(x == 0) {
382 return getVector(0, y-1).using_global ? 1 : 0;
384 int sum = 0;
385 sum += getVector(x - 1, y).using_global ? 1 : 0;
386 sum += getVector(x, y - 1).using_global ? 1 : 0;
387 sum += getVector(x - 1, y - 1).using_global ? 1 : 0;
388 return (sum >= 2) ? 1 : 0;
391 private void vectorPrediction(Vector mv, int x, int y, int mode) {
392 int n = 0, vx[] = new int[3], vy[] = new int[3];
393 if(x > 0) {
394 Vector ov = getVector(x-1, y);
395 if(!ov.using_global &&
396 (ov.pred_mode & mode) != 0) {
397 vx[n] = ov.dx[mode-1];
398 vy[n] = ov.dx[mode-1];
399 n++;
402 if(y > 0) {
403 Vector ov = getVector(x, y-1);
404 if(!ov.using_global &&
405 (ov.pred_mode & mode) != 0) {
406 vx[n] = ov.dx[mode-1];
407 vy[n] = ov.dx[mode-1];
408 n++;
411 if(x > 0 && y > 0) {
412 Vector ov = getVector(x - 1, y - 1);
413 if(!ov.using_global &&
414 (ov.pred_mode & mode) != 0) {
415 vx[n] = ov.dx[mode-1];
416 vy[n] = ov.dy[mode-1];
417 n++;
420 switch(n) {
421 case 0:
422 mv.dx[mode-1] = 0;
423 mv.dy[mode-1] = 0;
424 break;
425 case 1:
426 mv.dx[mode-1] = vx[0];
427 mv.dy[mode-1] = vy[0];
428 break;
429 case 2:
430 mv.dx[mode-1] = (vx[0] + vx[1] + 1) >> 1;
431 mv.dy[mode-1] = (vy[0] + vy[1] + 1) >> 1;
432 break;
433 case 3:
434 mv.dx[mode-1] = Util.median(vx);
435 mv.dy[mode-1] = Util.median(vy);
436 break;
440 private void dcPrediction(int x, int y, int pred[]) {
441 for(int i = 0; i < 3; i++) {
442 int sum = 0, n = 0;
443 if(x > 0) {
444 Vector ov = getVector(x - 1, y);
445 if(ov.pred_mode == 0) {
446 sum += ov.dc[i];
447 n++;
450 if(y > 0) {
451 Vector ov = getVector(x, y - 1);
452 if(ov.pred_mode == 0) {
453 sum += ov.dc[i];
454 n++;
457 if(x > 0 && y > 0) {
458 Vector ov = getVector(x - 1, y - 1);
459 if(ov.pred_mode == 0) {
460 sum += ov.dc[i];
461 n++;
464 switch(n) {
465 case 0:
466 pred[i] = 0;
467 break;
468 case 1:
469 pred[i] = sum;
470 break;
471 case 2:
472 pred[i] = (sum+1)>>1;
473 break;
474 case 3:
475 pred[i] = (sum+1)/3;
476 break;
481 private final Vector getVector(int x, int y) {
482 int pos = x + y*par.x_num_blocks;
483 if(vecs[pos] == null) {
484 vecs[pos] = new Vector();
486 return vecs[pos];
489 private final void setVector(Vector mv, int x, int y) {
490 vecs[x + y*par.x_num_blocks] = mv;