Renamed package to Jirac
[jirac.git] / org / diracvideo / Jirac / Motion.java
blob5d6f46a5afd1439765c8cbc6738326431bd1618a
1 package org.diracvideo.Jirac;
2 import java.awt.Dimension;
3 import java.awt.Point;
5 /** Motion
7 * An ill-named class representing an object
8 * which does motion compensation prediction
9 * on a picture. **/
10 class Motion {
11 Parameters par;
12 Vector vecs[];
13 Picture refs[];
14 Arithmetic ar[];
15 int xbsep, ybsep, xblen, yblen, xoffset, yoffset;
16 int chroma_h_shift, chroma_v_shift;
17 short weight_x[], weight_y[], obmc[];
18 Block block, tmp_ref[];
20 static int ARITH_SUPERBLOCK = 0;
21 static int ARITH_PRED_MODE = 1;
22 static int ARITH_REF1_X = 2;
23 static int ARITH_REF1_Y = 3;
24 static int ARITH_REF2_X = 4;
25 static int ARITH_REF2_Y = 5;
26 static int ARITH_DC_0 = 6;
27 static int ARITH_DC_1 = 7;
28 static int ARITH_DC_2 = 8;
30 public Motion(Parameters p, Picture r[]) {
31 par = p;
32 refs = r;
33 vecs = new Vector[par.x_num_blocks * par.y_num_blocks];
34 tmp_ref = new Block[refs.length];
37 private void initialize(Buffer bufs[]) {
38 ar = new Arithmetic[9];
39 for(int i = 0; i < 9; i++) {
40 if(bufs[i] == null)
41 continue;
42 ar[i] = new Arithmetic(bufs[i]);
46 public void decode(Buffer bufs[]) {
47 initialize(bufs);
48 for(int y = 0; y < par.y_num_blocks; y += 4)
49 for(int x = 0; x < par.x_num_blocks; x += 4)
50 decodeMacroBlock(x,y);
53 private void decodeMacroBlock(int x, int y) {
54 int split = splitPrediction(x,y);
55 Vector mv = getVector(x,y);
56 mv.split = (split + ar[ARITH_SUPERBLOCK].decodeUint(Context.SB_F1, Context.SB_DATA))%3;
57 switch(mv.split) {
58 case 0:
59 decodePredictionUnit(mv, x, y);
60 for(int i = 0; i < 4; i++)
61 for(int j = 0; j < 4; j++)
62 setVector(mv, x + j, y + i);
63 break;
64 case 1:
65 for(int i = 0; i < 4; i += 2)
66 for(int j = 0; j < 4; j += 2) {
67 mv = getVector(x + j, y + i);
68 mv.split = 1;
69 decodePredictionUnit(mv, x + j, y + i);
70 setVector(mv, x + j + 1, y + i);
71 setVector(mv, x + j, y + i + 1);
72 setVector(mv, x + j + 1, y + i + 1);
74 break;
75 case 2:
76 for(int i = 0; i < 4; i++)
77 for(int j = 0; j < 4; j++) {
78 mv = getVector(x + j, y + i);
79 mv.split = 2;
80 decodePredictionUnit(mv, x + j, y + i);
82 break;
83 default:
84 throw new Error("Unsupported splitting mode");
// Decodes the prediction data of the prediction unit whose top-left block is (x, y)
// into mv.
// pred_mode is predicted from the neighbours (modePrediction) and corrected by XOR
// with one coded bit for reference 1 and, when par.num_refs > 1, one bit (shifted
// to bit 1) for reference 2.
// pred_mode == 0 means an intra/DC unit: three DC values (luma, chroma1, chroma2)
// are decoded as residuals against dcPrediction.
// Otherwise the unit may use global motion (flag predicted by globalPrediction and
// XOR-corrected). If it does not, a vector is decoded for each reference flagged in
// pred_mode as a residual added to vectorPrediction's result.
88 private void decodePredictionUnit(Vector mv, int x, int y) {
89 mv.pred_mode = modePrediction(x,y);
90 mv.pred_mode ^= ar[ARITH_PRED_MODE].decodeBit(Context.BLOCK_MODE_REF1);
91 if(par.num_refs > 1) {
92 mv.pred_mode ^= (ar[ARITH_PRED_MODE].decodeBit(Context.BLOCK_MODE_REF2) << 1);
94 if(mv.pred_mode == 0) {
// DC unit: residual + predicted DC, per component.
95 int pred[] = new int[3];
96 dcPrediction(x,y,pred);
97 mv.dc[0] = pred[0] +
98 ar[ARITH_DC_0].decodeSint(Context.LUMA_DC_CONT_BIN1,
99 Context.LUMA_DC_VALUE,
100 Context.LUMA_DC_SIGN);
101 mv.dc[1] = pred[1] +
102 ar[ARITH_DC_1].decodeSint(Context.CHROMA1_DC_CONT_BIN1,
103 Context.CHROMA1_DC_VALUE,
104 Context.CHROMA1_DC_SIGN);
105 mv.dc[2] = pred[2] +
106 ar[ARITH_DC_2].decodeSint(Context.CHROMA2_DC_CONT_BIN1,
107 Context.CHROMA2_DC_VALUE,
108 Context.CHROMA2_DC_SIGN);
109 } else {
// NOTE(review): pred_x / pred_y are declared but never used.
110 int pred_x, pred_y;
111 if(par.have_global_motion) {
// NOTE(review): the global-block flag is read from the ARITH_SUPERBLOCK decoder,
// not ARITH_PRED_MODE -- confirm this matches the bitstream layout.
112 int pred = globalPrediction(x,y);
113 pred ^= ar[ARITH_SUPERBLOCK].decodeBit(Context.GLOBAL_BLOCK);
114 mv.using_global = (pred == 0 ? false : true);
115 } else {
116 mv.using_global = false;
118 if(!mv.using_global) {
// Reference 1 (bit 0 of pred_mode): vector stored in dx[0]/dy[0].
119 if((mv.pred_mode & 1) != 0) {
120 vectorPrediction(mv,x,y,1);
121 mv.dx[0] +=
122 ar[ARITH_REF1_X].decodeSint(Context.MV_REF1_H_CONT_BIN1,
123 Context.MV_REF1_H_VALUE,
124 Context.MV_REF1_H_SIGN);
125 mv.dy[0] +=
126 ar[ARITH_REF1_Y].decodeSint(Context.MV_REF1_V_CONT_BIN1,
127 Context.MV_REF1_V_VALUE,
128 Context.MV_REF1_V_SIGN);
// Reference 2 (bit 1 of pred_mode): vector stored in dx[1]/dy[1].
130 if((mv.pred_mode & 2) != 0) {
131 vectorPrediction(mv, x, y, 2);
132 mv.dx[1] += ar[ARITH_REF2_X].decodeSint(Context.MV_REF2_H_CONT_BIN1,
133 Context.MV_REF2_H_VALUE,
134 Context.MV_REF2_H_SIGN);
135 mv.dy[1] += ar[ARITH_REF2_Y].decodeSint(Context.MV_REF2_V_CONT_BIN1,
136 Context.MV_REF2_V_VALUE,
137 Context.MV_REF2_V_SIGN);
// NOTE(review): as shown, the following statements clear using_global and every
// vector component right after they were decoded, which discards the decoded
// motion. This looks like a debugging leftover, or it may sit inside a comment
// block whose delimiters were lost in this copy -- confirm against the repository.
142 mv.using_global = false;
143 mv.dx[0] = 0;
144 mv.dy[0] = 0;
145 mv.dx[1] = 0;
146 mv.dy[1] = 0;
149 public void render(Block out[], VideoFormat f) {
150 for(int k = 0; k < out.length; k++) {
151 initializeRender(k,f);
152 block = new Block(new Dimension(xblen, yblen));
153 for(int i = 0; i < par.num_refs; i++) {
154 tmp_ref[i] = refs[i].getComponent(k);
155 if(par.mv_precision > 0)
156 tmp_ref[i] = tmp_ref[i].upSample();
158 for(int j = 0; j < par.y_num_blocks; j++)
159 for(int i = 0; i < par.x_num_blocks; i++) {
160 predictBlock(out[k], i, j, k);
161 accumulateBlock(out[k], i*xbsep - xoffset,
162 j*ybsep - yoffset);
164 out[k].shiftOut(6,0);
165 out[k].clip(7);
169 private void initializeRender(int k, VideoFormat f) {
170 chroma_h_shift = f.chromaHShift();
171 chroma_v_shift = f.chromaVShift();
172 yblen = par.yblen_luma;
173 xblen = par.xblen_luma;
174 ybsep = par.ybsep_luma;
175 xbsep = par.xbsep_luma;
176 if(k != 0) {
177 yblen >>= chroma_v_shift;
178 ybsep >>= chroma_v_shift;
179 xbsep >>= chroma_h_shift;
180 xblen >>= chroma_h_shift;
182 yoffset = (yblen - ybsep) >> 1;
183 xoffset = (xblen - xbsep) >> 1;
184 /* initialize obmc weight */
185 weight_y = new short[yblen];
186 weight_x = new short[xblen];
187 obmc = new short[xblen*yblen];
188 for(int i = 0; i < xblen; i++) {
189 short wx;
190 if(xoffset == 0) {
191 wx = 8;
192 } else if( i < 2*xoffset) {
193 wx = Util.getRamp(i, xoffset);
194 } else if(xblen - 1 - i < 2*xoffset) {
195 wx = Util.getRamp(xblen - 1 - i, xoffset);
196 } else {
197 wx = 8;
199 weight_x[i] = wx;
201 for(int j = 0; j < yblen; j++) {
202 short wy;
203 if(yoffset == 0) {
204 wy = 8;
205 } else if(j < 2*yoffset) {
206 wy = Util.getRamp(j, yoffset);
207 } else if(yblen - 1 - j < 2*yoffset) {
208 wy = Util.getRamp(yblen - 1 - j, yoffset);
209 } else {
210 wy = 8;
212 weight_y[j] = wy;
217 private void dumpWeights() {
218 System.err.println("weight_x");
219 for(int i = 0; i < xblen; i++)
220 System.err.format("%d ", weight_x[i]);
221 System.err.println("\nweight_y");
222 for(int i = 0; i < yblen; i++)
223 System.err.format("%d ", weight_y[i]);
224 System.err.println("");
227 private void predictBlock(Block out, int i, int j, int k) {
228 int xstart = (i*xbsep) - xoffset,
229 ystart = (j*ybsep) - yoffset;
230 Vector mv = getVector(i,j);
231 if(mv.pred_mode == 0) {
232 for(int q = 0; j < yblen; j++)
233 for(int p = 0; i < xblen; i++)
234 block.set(p, q, (mv.dc[k]));
236 if(k != 0 && !mv.using_global)
237 mv = mv.scale(chroma_h_shift, chroma_v_shift);
238 for(int q = 0; q < yblen; q++) {
239 int y = ystart + q;
240 if(y < 0 || y > out.s.height - 1) continue;
241 for(int p = 0; p < xblen; p++) {
242 int x = xstart + p;
243 if(x < 0 || x > out.s.width - 1) continue;
244 block.set(p,q, predictPixel(mv, x, y, k));
249 private short predictPixel(Vector mv, int x, int y, int k) {
250 if(mv.using_global) {
251 for(int i = 0; i < par.num_refs; i++) {
252 par.global[i].getVector(mv, x, y, i);
254 if(k != 0)
255 mv = mv.scale(chroma_h_shift, chroma_v_shift);
257 short weight = (short)(par.picture_weight_1 + par.picture_weight_2);
258 short val = 0;
259 int px, py;
260 switch(mv.pred_mode) {
261 case 1:
262 px = (x << par.mv_precision) + mv.dx[0];
263 py = (y << par.mv_precision) + mv.dy[0];
264 val = (short)(weight*predictSubPixel(0, px, py));
265 break;
266 case 2:
267 px = (x << par.mv_precision) + mv.dx[1];
268 py = (y << par.mv_precision) + mv.dy[1];
269 val = (short)(weight*predictSubPixel(1, px, py));
270 break;
271 case 3:
272 px = (x << par.mv_precision) + mv.dx[0];
273 py = (y << par.mv_precision) + mv.dy[0];
274 val = (short)(par.picture_weight_1*predictSubPixel(0, px, py));
275 px = (x << par.mv_precision) + mv.dx[1];
276 py = (x << par.mv_precision) + mv.dy[1];
277 val += (short)(par.picture_weight_2*predictSubPixel(1, px, py));
278 default:
279 break;
281 return (short)Util.roundShift(val, par.picture_weight_bits);
284 private short predictSubPixel(int ref, int px, int py) {
285 if(par.mv_precision < 2) {
286 return tmp_ref[ref].real(px, py);
288 int prec = par.mv_precision;
289 int add = 1 << (prec - 1);
290 int hx = px >> (prec-1);
291 int hy = py >> (prec-1);
292 int rx = px - (hx << (prec-1));
293 int ry = py - (hy << (prec-1));
294 int w00,w01, w10, w11;
295 w00 = (add - rx)*(add - ry);
296 w01 = (add - rx)*ry;
297 w10 = rx*(add - ry);
298 w11 = rx*ry;
299 int val = w00*tmp_ref[ref].real(hx, hy) +
300 w01*tmp_ref[ref].real(hx + 1, hy) +
301 w10*tmp_ref[ref].real(hx, hy + 1) +
302 w11*tmp_ref[ref].real(hx + 1, hy + 1);
303 return (short)((val + (1 << (2*prec-3))) >> (2*prec - 2));
307 private void accumulateBlock(Block out, int x, int y) {
308 if(!edge(x,y)) {
309 for(int q = 0; q < yblen; q++) {
310 if(q + y < 0 || q + y >= out.s.height) continue;
311 int outLine = out.index(x, y + q);
312 int inLine = block.line(q);
313 for(int p = 0; p < xblen; p++) {
314 if(p + x < 0 || p + x >= out.s.width) continue;
315 out.d[outLine + p] +=
316 (short)(weight_x[p]*weight_y[q]*block.d[inLine+p]);
319 } else {
320 int w_x, w_y;
321 for(int q = 0; q < yblen; q++) {
322 if(q + y < 0 || q + y >= out.s.height) continue;
323 if((y < 0 && q < 2*yoffset) ||
324 (y >= par.y_num_blocks*ybsep - yoffset &&
325 yblen - 1 - q < 2*yoffset))
326 w_y = 8;
327 else
328 w_y = weight_y[q];
329 int outLine = out.index(x, y + q);
330 int inLine = block.line(q);
331 for(int p = 0; p < xblen; p++) {
332 if(p + x < 0 || p + x >= out.s.width) continue;
333 if((x < 0 && p < 2*xoffset) ||
334 (x >= par.x_num_blocks*xbsep - xoffset &&
335 xblen - 1 - p < 2*xoffset))
336 w_x = 8;
337 else
338 w_x = weight_x[p];
339 out.d[outLine + p] +=
340 (short)(w_x*w_y*block.d[inLine+p]);
348 private boolean edge(int x, int y) {
349 return (x < 0 || x >= par.x_num_blocks*xbsep - xoffset)
350 || (y < 0 || y >= par.y_num_blocks*ybsep - yoffset);
354 private int splitPrediction(int x, int y) {
355 if(y == 0) {
356 if(x == 0) {
357 return 0;
358 } else {
359 return vecs[x-4].split;
361 } else {
362 if(x == 0) {
363 return getVector(0, y - 4).split;
364 } else {
365 int sum = 0;
366 sum += getVector(x, y - 4).split;
367 sum += getVector(x - 4, y).split;
368 sum += getVector(x - 4, y - 4).split;
369 return (sum+1)/3;
374 private int modePrediction(int x, int y) {
375 if(y == 0) {
376 if(x == 0) {
377 return 0;
378 } else {
379 return vecs[x - 1].pred_mode;
381 } else {
382 if(x == 0) {
383 return getVector(0, y - 1).pred_mode;
384 } else {
385 int a,b,c;
386 a = getVector(x - 1, y).pred_mode;
387 b = getVector(x, y - 1).pred_mode;
388 c = getVector(x - 1, y - 1).pred_mode;
389 return (a&b)|(b&c)|(c&a);
394 private int globalPrediction(int x, int y) {
395 if(x == 0 && y == 0) {
396 return 0;
398 if(y == 0) {
399 return vecs[x-1].using_global ? 1 : 0;
401 if(x == 0) {
402 return getVector(0, y-1).using_global ? 1 : 0;
404 int sum = 0;
405 sum += getVector(x - 1, y).using_global ? 1 : 0;
406 sum += getVector(x, y - 1).using_global ? 1 : 0;
407 sum += getVector(x - 1, y - 1).using_global ? 1 : 0;
408 return (sum >= 2) ? 1 : 0;
411 private void vectorPrediction(Vector mv, int x, int y, int mode) {
412 int n = 0, vx[] = new int[3], vy[] = new int[3];
413 if(x > 0) {
414 Vector ov = getVector(x-1, y);
415 if(!ov.using_global &&
416 (ov.pred_mode & mode) != 0) {
417 vx[n] = ov.dx[mode-1];
418 vy[n] = ov.dx[mode-1];
419 n++;
422 if(y > 0) {
423 Vector ov = getVector(x, y-1);
424 if(!ov.using_global &&
425 (ov.pred_mode & mode) != 0) {
426 vx[n] = ov.dx[mode-1];
427 vy[n] = ov.dx[mode-1];
428 n++;
431 if(x > 0 && y > 0) {
432 Vector ov = getVector(x - 1, y - 1);
433 if(!ov.using_global &&
434 (ov.pred_mode & mode) != 0) {
435 vx[n] = ov.dx[mode-1];
436 vy[n] = ov.dy[mode-1];
437 n++;
440 switch(n) {
441 case 0:
442 mv.dx[mode-1] = 0;
443 mv.dy[mode-1] = 0;
444 break;
445 case 1:
446 mv.dx[mode-1] = vx[0];
447 mv.dy[mode-1] = vy[0];
448 break;
449 case 2:
450 mv.dx[mode-1] = (vx[0] + vx[1] + 1) >> 1;
451 mv.dy[mode-1] = (vy[0] + vy[1] + 1) >> 1;
452 break;
453 case 3:
454 mv.dx[mode-1] = Util.median(vx);
455 mv.dy[mode-1] = Util.median(vy);
456 break;
460 private void dcPrediction(int x, int y, int pred[]) {
461 for(int i = 0; i < 3; i++) {
462 int sum = 0, n = 0;
463 if(x > 0) {
464 Vector ov = getVector(x - 1, y);
465 if(ov.pred_mode == 0) {
466 sum += ov.dc[i];
467 n++;
470 if(y > 0) {
471 Vector ov = getVector(x, y - 1);
472 if(ov.pred_mode == 0) {
473 sum += ov.dc[i];
474 n++;
477 if(x > 0 && y > 0) {
478 Vector ov = getVector(x - 1, y - 1);
479 if(ov.pred_mode == 0) {
480 sum += ov.dc[i];
481 n++;
484 switch(n) {
485 case 0:
486 pred[i] = 0;
487 break;
488 case 1:
489 pred[i] = sum;
490 break;
491 case 2:
492 pred[i] = (sum+1)>>1;
493 break;
494 case 3:
495 pred[i] = (sum+1)/3;
496 break;
501 private final Vector getVector(int x, int y) {
502 int pos = x + y*par.x_num_blocks;
503 if(vecs[pos] == null) {
504 vecs[pos] = new Vector();
506 return vecs[pos];
509 private final void setVector(Vector mv, int x, int y) {
510 vecs[x + y*par.x_num_blocks] = mv;