diff --git a/scripts/alignment_heads_qwen3_asr_1.7B.json b/scripts/alignment_heads_qwen3_asr_1.7B.json new file mode 100644 index 0000000..52c7e09 --- /dev/null +++ b/scripts/alignment_heads_qwen3_asr_1.7B.json @@ -0,0 +1,3445 @@ +{ + "model": "Qwen/Qwen3-ASR-1.7B", + "language": "English", + "num_layers": 28, + "num_heads": 16, + "num_kv_heads": 8, + "num_samples": 100, + "total_alignable_tokens": 1125, + "ts_threshold": 0.1, + "ts_matrix": [ + [ + 0.10222222222222223, + 0.09333333333333334, + 0.10133333333333333, + 0.10755555555555556, + 0.056, + 0.06933333333333333, + 0.07644444444444444, + 0.07466666666666667, + 0.08533333333333333, + 0.09422222222222222, + 0.13155555555555556, + 0.1431111111111111, + 0.05333333333333334, + 0.041777777777777775, + 0.05422222222222222, + 0.07466666666666667 + ], + [ + 0.15733333333333333, + 0.15555555555555556, + 0.096, + 0.14044444444444446, + 0.064, + 0.056, + 0.06933333333333333, + 0.07377777777777778, + 0.3502222222222222, + 0.06311111111111112, + 0.08533333333333333, + 0.04711111111111111, + 0.03111111111111111, + 0.17155555555555554, + 0.13155555555555556, + 0.5191111111111111 + ], + [ + 0.06488888888888888, + 0.056, + 0.2577777777777778, + 0.6417777777777778, + 0.08177777777777778, + 0.06844444444444445, + 0.192, + 0.07288888888888889, + 0.3457777777777778, + 0.08711111111111111, + 0.6604444444444444, + 0.6666666666666666, + 0.08266666666666667, + 0.1111111111111111, + 0.36977777777777776, + 0.12355555555555556 + ], + [ + 0.11822222222222223, + 0.12622222222222224, + 0.16444444444444445, + 0.18488888888888888, + 0.256, + 0.088, + 0.09155555555555556, + 0.07555555555555556, + 0.11377777777777778, + 0.11733333333333333, + 0.6853333333333333, + 0.616, + 0.12533333333333332, + 0.26755555555555555, + 0.20266666666666666, + 0.20355555555555555 + ], + [ + 0.030222222222222223, + 0.034666666666666665, + 0.11644444444444445, + 0.10577777777777778, + 0.11911111111111111, + 0.06933333333333333, + 0.029333333333333333, + 0.09333333333333334, + 0.12266666666666666, + 0.09244444444444444, + 0.3831111111111111, + 0.20533333333333334, + 0.43555555555555553, + 0.6542222222222223, + 0.08266666666666667, + 0.25955555555555554 + ], + [ + 0.10755555555555556, + 0.10133333333333333, + 0.08533333333333333, + 0.07022222222222223, + 0.13866666666666666, + 0.22133333333333333, + 0.11911111111111111, + 0.12622222222222224, + 0.1288888888888889, + 0.12977777777777777, + 0.44355555555555554, + 0.12266666666666666, + 0.05422222222222222, + 0.04888888888888889, + 0.152, + 0.32266666666666666 + ], + [ + 0.25244444444444447, + 0.21422222222222223, + 0.08088888888888889, + 0.12444444444444444, + 0.17155555555555554, + 0.13955555555555554, + 0.7288888888888889, + 0.7315555555555555, + 0.03288888888888889, + 0.24888888888888888, + 0.7146666666666667, + 0.7031111111111111, + 0.6417777777777778, + 0.6888888888888889, + 0.18666666666666668, + 0.1511111111111111 + ], + [ + 0.13422222222222221, + 0.03822222222222222, + 0.07022222222222223, + 0.08177777777777778, + 0.29155555555555557, + 0.1368888888888889, + 0.16444444444444445, + 0.07733333333333334, + 0.09244444444444444, + 0.030222222222222223, + 0.13155555555555556, + 0.14844444444444443, + 0.12444444444444444, + 0.22755555555555557, + 0.12622222222222224, + 0.17244444444444446 + ], + [ + 0.12266666666666666, + 0.6008888888888889, + 0.14844444444444443, + 0.06577777777777778, + 0.6488888888888888, + 0.3546666666666667, + 0.23644444444444446, + 0.296, + 0.10311111111111111, + 0.13155555555555556, + 0.17422222222222222, + 0.14666666666666667, + 0.136, + 0.1991111111111111, + 0.3111111111111111, + 0.09333333333333334 + ], + [ + 0.1902222222222222, + 0.03822222222222222, + 0.1608888888888889, + 0.09155555555555556, + 0.18844444444444444, + 0.19466666666666665, + 0.04533333333333334, + 0.1671111111111111, + 0.22844444444444445, + 0.23644444444444446, + 0.17333333333333334, + 0.11555555555555555, + 0.49422222222222223, + 0.41244444444444445, + 0.12977777777777777, + 0.018666666666666668 + ], + [ + 0.028444444444444446, + 0.04622222222222222, + 0.18222222222222223, + 0.25066666666666665, + 0.17866666666666667, + 0.32266666666666666, + 0.051555555555555556, + 0.07822222222222222, + 0.1448888888888889, + 0.152, + 0.0791111111111111, + 0.15733333333333333, + 0.1111111111111111, + 0.14844444444444443, + 0.04711111111111111, + 0.10044444444444445 + ], + [ + 0.18577777777777776, + 0.22044444444444444, + 0.7573333333333333, + 0.7182222222222222, + 0.11288888888888889, + 0.168, + 0.18044444444444444, + 0.2577777777777778, + 0.18933333333333333, + 0.11377777777777778, + 0.2871111111111111, + 0.6168888888888889, + 0.7093333333333334, + 0.7484444444444445, + 0.050666666666666665, + 0.11288888888888889 + ], + [ + 0.344, + 0.37155555555555553, + 0.16977777777777778, + 0.2551111111111111, + 0.0791111111111111, + 0.12, + 0.5511111111111111, + 0.07555555555555556, + 0.31733333333333336, + 0.09688888888888889, + 0.23733333333333334, + 0.06666666666666667, + 0.17155555555555554, + 0.10844444444444444, + 0.21244444444444444, + 0.20355555555555555 + ], + [ + 0.6124444444444445, + 0.192, + 0.18044444444444444, + 0.1288888888888889, + 0.3848888888888889, + 0.136, + 0.48533333333333334, + 0.5022222222222222, + 0.034666666666666665, + 0.04888888888888889, + 0.088, + 0.6702222222222223, + 0.025777777777777778, + 0.03822222222222222, + 0.5964444444444444, + 0.4231111111111111 + ], + [ + 0.19377777777777777, + 0.09066666666666667, + 0.16355555555555557, + 0.07466666666666667, + 0.051555555555555556, + 0.2222222222222222, + 0.18666666666666668, + 0.14666666666666667, + 0.064, + 0.07822222222222222, + 0.18755555555555556, + 0.23644444444444446, + 0.42133333333333334, + 0.21066666666666667, + 0.7351111111111112, + 0.7164444444444444 + ], + [ + 0.12622222222222224, + 0.168, + 0.1751111111111111, + 0.152, + 0.18488888888888888, + 0.1751111111111111, + 0.21866666666666668, + 0.10933333333333334, + 0.07555555555555556, + 0.16533333333333333, + 0.3111111111111111, + 0.16177777777777777, + 0.04088888888888889, + 0.037333333333333336, + 0.18488888888888888, + 0.11466666666666667 + ], + [ + 0.05333333333333334, + 0.041777777777777775, + 0.11377777777777778, + 0.15911111111111112, + 0.11555555555555555, + 0.13333333333333333, + 0.16444444444444445, + 0.4817777777777778, + 0.25422222222222224, + 0.264, + 0.648, + 0.5493333333333333, + 0.2995555555555556, + 0.4017777777777778, + 0.7573333333333333, + 0.6977777777777778 + ], + [ + 0.25866666666666666, + 0.25955555555555554, + 0.2328888888888889, + 0.18133333333333335, + 0.08444444444444445, + 0.058666666666666666, + 0.042666666666666665, + 0.22933333333333333, + 0.34044444444444444, + 0.24533333333333332, + 0.23822222222222222, + 0.18577777777777776, + 0.248, + 0.4017777777777778, + 0.11644444444444445, + 0.112 + ], + [ + 0.07377777777777778, + 0.07733333333333334, + 0.37244444444444447, + 0.6417777777777778, + 0.27466666666666667, + 0.6515555555555556, + 0.18222222222222223, + 0.16177777777777777, + 0.11377777777777778, + 0.07466666666666667, + 0.37777777777777777, + 0.1991111111111111, + 0.042666666666666665, + 0.19733333333333333, + 0.08711111111111111, + 0.2 + ], + [ + 0.16977777777777778, + 0.17066666666666666, + 0.31022222222222223, + 0.544, + 0.4391111111111111, + 0.6391111111111111, + 0.17066666666666666, + 0.712, + 0.4311111111111111, + 0.5022222222222222, + 0.07466666666666667, + 0.08711111111111111, + 0.3662222222222222, + 0.4017777777777778, + 0.04888888888888889, + 0.08266666666666667 + ], + [ + 0.10044444444444445, + 0.10844444444444444, + 0.15911111111111112, + 0.7644444444444445, + 0.3448888888888889, + 0.16177777777777777, + 0.3635555555555556, + 0.5031111111111111, + 0.31733333333333336, + 0.06933333333333333, + 0.5022222222222222, + 0.5742222222222222, + 0.3297777777777778, + 0.23644444444444446, + 0.6551111111111111, + 0.5831111111111111 + ], + [ + 0.5146666666666667, + 0.5031111111111111, + 0.112, + 0.07111111111111111, + 0.2391111111111111, + 0.15555555555555556, + 0.24266666666666667, + 0.18844444444444444, + 0.7386666666666667, + 0.7617777777777778, + 0.25066666666666665, + 0.352, + 0.5457777777777778, + 0.4088888888888889, + 0.3128888888888889, + 0.36177777777777775 + ], + [ + 0.21155555555555555, + 0.26666666666666666, + 0.10488888888888889, + 0.06222222222222222, + 0.288, + 0.25066666666666665, + 0.2995555555555556, + 0.6515555555555556, + 0.5955555555555555, + 0.6302222222222222, + 0.24977777777777777, + 0.2568888888888889, + 0.6195555555555555, + 0.5431111111111111, + 0.23466666666666666, + 0.08622222222222223 + ], + [ + 0.48977777777777776, + 0.5102222222222222, + 0.05688888888888889, + 0.06311111111111112, + 0.6222222222222222, + 0.4142222222222222, + 0.24888888888888888, + 0.6462222222222223, + 0.06488888888888888, + 0.1608888888888889, + 0.3537777777777778, + 0.31822222222222224, + 0.20177777777777778, + 0.1448888888888889, + 0.6275555555555555, + 0.6044444444444445 + ], + [ + 0.036444444444444446, + 0.048, + 0.06222222222222222, + 0.07377777777777778, + 0.42933333333333334, + 0.6257777777777778, + 0.5306666666666666, + 0.6008888888888889, + 0.09066666666666667, + 0.072, + 0.5493333333333333, + 0.5804444444444444, + 0.5866666666666667, + 0.5937777777777777, + 0.6257777777777778, + 0.6204444444444445 + ], + [ + 0.09066666666666667, + 0.11733333333333333, + 0.059555555555555556, + 0.07022222222222223, + 0.5982222222222222, + 0.648, + 0.5875555555555556, + 0.5964444444444444, + 0.352, + 0.4888888888888889, + 0.5715555555555556, + 0.6035555555555555, + 0.5875555555555556, + 0.5804444444444444, + 0.5688888888888889, + 0.3546666666666667 + ], + [ + 0.376, + 0.3217777777777778, + 0.5786666666666667, + 0.5466666666666666, + 0.5475555555555556, + 0.5155555555555555, + 0.1688888888888889, + 0.5528888888888889, + 0.6142222222222222, + 0.21511111111111111, + 0.08622222222222223, + 0.20533333333333334, + 0.13066666666666665, + 0.10222222222222223, + 0.5511111111111111, + 0.4951111111111111 + ], + [ + 0.08177777777777778, + 0.10044444444444445, + 0.08711111111111111, + 0.08888888888888889, + 0.08533333333333333, + 0.056, + 0.15466666666666667, + 0.07377777777777778, + 0.04888888888888889, + 0.07022222222222223, + 0.10222222222222223, + 0.0951111111111111, + 0.08088888888888889, + 0.06311111111111112, + 0.09688888888888889, + 0.07111111111111111 + ] + ], + "alignment_heads": [ + { + "layer": 20, + "head": 3, + "ts": 0.7644 + }, + { + "layer": 21, + "head": 9, + "ts": 0.7618 + }, + { + "layer": 11, + "head": 2, + "ts": 0.7573 + }, + { + "layer": 16, + "head": 14, + "ts": 0.7573 + }, + { + "layer": 11, + "head": 13, + "ts": 0.7484 + }, + { + "layer": 21, + "head": 8, + "ts": 0.7387 + }, + { + "layer": 14, + "head": 14, + "ts": 0.7351 + }, + { + "layer": 6, + "head": 7, + "ts": 0.7316 + }, + { + "layer": 6, + "head": 6, + "ts": 0.7289 + }, + { + "layer": 11, + "head": 3, + "ts": 0.7182 + }, + { + "layer": 14, + "head": 15, + "ts": 0.7164 + }, + { + "layer": 6, + "head": 10, + "ts": 0.7147 + }, + { + "layer": 19, + "head": 7, + "ts": 0.712 + }, + { + "layer": 11, + "head": 12, + "ts": 0.7093 + }, + { + "layer": 6, + "head": 11, + "ts": 0.7031 + }, + { + "layer": 16, + "head": 15, + "ts": 0.6978 + }, + { + "layer": 6, + "head": 13, + "ts": 0.6889 + }, + { + "layer": 3, + "head": 10, + "ts": 0.6853 + }, + { + "layer": 13, + "head": 11, + "ts": 0.6702 + }, + { + "layer": 2, + "head": 11, + "ts": 0.6667 + }, + { + "layer": 2, + "head": 10, + "ts": 0.6604 + }, + { + "layer": 20, + "head": 14, + "ts": 0.6551 + }, + { + "layer": 4, + "head": 13, + "ts": 0.6542 + }, + { + "layer": 18, + "head": 5, + "ts": 0.6516 + }, + { + "layer": 22, + "head": 7, + "ts": 0.6516 + }, + { + "layer": 8, + "head": 4, + "ts": 0.6489 + }, + { + "layer": 16, + "head": 10, + "ts": 0.648 + }, + { + "layer": 25, + "head": 5, + "ts": 0.648 + }, + { + "layer": 23, + "head": 7, + "ts": 0.6462 + }, + { + "layer": 2, + "head": 3, + "ts": 0.6418 + }, + { + "layer": 6, + "head": 12, + "ts": 0.6418 + }, + { + "layer": 18, + "head": 3, + "ts": 0.6418 + }, + { + "layer": 19, + "head": 5, + "ts": 0.6391 + }, + { + "layer": 22, + "head": 9, + "ts": 0.6302 + }, + { + "layer": 23, + "head": 14, + "ts": 0.6276 + }, + { + "layer": 24, + "head": 5, + "ts": 0.6258 + }, + { + "layer": 24, + "head": 14, + "ts": 0.6258 + }, + { + "layer": 23, + "head": 4, + "ts": 0.6222 + }, + { + "layer": 24, + "head": 15, + "ts": 0.6204 + }, + { + "layer": 22, + "head": 12, + "ts": 0.6196 + }, + { + "layer": 11, + "head": 11, + "ts": 0.6169 + }, + { + "layer": 3, + "head": 11, + "ts": 0.616 + }, + { + "layer": 26, + "head": 8, + "ts": 0.6142 + }, + { + "layer": 13, + "head": 0, + "ts": 0.6124 + }, + { + "layer": 23, + "head": 15, + "ts": 0.6044 + }, + { + "layer": 25, + "head": 11, + "ts": 0.6036 + }, + { + "layer": 8, + "head": 1, + "ts": 0.6009 + }, + { + "layer": 24, + "head": 7, + "ts": 0.6009 + }, + { + "layer": 25, + "head": 4, + "ts": 0.5982 + }, + { + "layer": 13, + "head": 14, + "ts": 0.5964 + }, + { + "layer": 25, + "head": 7, + "ts": 0.5964 + }, + { + "layer": 22, + "head": 8, + "ts": 0.5956 + }, + { + "layer": 24, + "head": 13, + "ts": 0.5938 + }, + { + "layer": 25, + "head": 6, + "ts": 0.5876 + }, + { + "layer": 25, + "head": 12, + "ts": 0.5876 + }, + { + "layer": 24, + "head": 12, + "ts": 0.5867 + }, + { + "layer": 20, + "head": 15, + "ts": 0.5831 + }, + { + "layer": 24, + "head": 11, + "ts": 0.5804 + }, + { + "layer": 25, + "head": 13, + "ts": 0.5804 + }, + { + "layer": 26, + "head": 2, + "ts": 0.5787 + }, + { + "layer": 20, + "head": 11, + "ts": 0.5742 + }, + { + "layer": 25, + "head": 10, + "ts": 0.5716 + }, + { + "layer": 25, + "head": 14, + "ts": 0.5689 + }, + { + "layer": 26, + "head": 7, + "ts": 0.5529 + }, + { + "layer": 12, + "head": 6, + "ts": 0.5511 + }, + { + "layer": 26, + "head": 14, + "ts": 0.5511 + }, + { + "layer": 16, + "head": 11, + "ts": 0.5493 + }, + { + "layer": 24, + "head": 10, + "ts": 0.5493 + }, + { + "layer": 26, + "head": 4, + "ts": 0.5476 + }, + { + "layer": 26, + "head": 3, + "ts": 0.5467 + }, + { + "layer": 21, + "head": 12, + "ts": 0.5458 + }, + { + "layer": 19, + "head": 3, + "ts": 0.544 + }, + { + "layer": 22, + "head": 13, + "ts": 0.5431 + }, + { + "layer": 24, + "head": 6, + "ts": 0.5307 + }, + { + "layer": 1, + "head": 15, + "ts": 0.5191 + }, + { + "layer": 26, + "head": 5, + "ts": 0.5156 + }, + { + "layer": 21, + "head": 0, + "ts": 0.5147 + }, + { + "layer": 23, + "head": 1, + "ts": 0.5102 + }, + { + "layer": 20, + "head": 7, + "ts": 0.5031 + }, + { + "layer": 21, + "head": 1, + "ts": 0.5031 + }, + { + "layer": 13, + "head": 7, + "ts": 0.5022 + }, + { + "layer": 19, + "head": 9, + "ts": 0.5022 + }, + { + "layer": 20, + "head": 10, + "ts": 0.5022 + }, + { + "layer": 26, + "head": 15, + "ts": 0.4951 + }, + { + "layer": 9, + "head": 12, + "ts": 0.4942 + }, + { + "layer": 23, + "head": 0, + "ts": 0.4898 + }, + { + "layer": 25, + "head": 9, + "ts": 0.4889 + }, + { + "layer": 13, + "head": 6, + "ts": 0.4853 + }, + { + "layer": 16, + "head": 7, + "ts": 0.4818 + }, + { + "layer": 5, + "head": 10, + "ts": 0.4436 + }, + { + "layer": 19, + "head": 4, + "ts": 0.4391 + }, + { + "layer": 4, + "head": 12, + "ts": 0.4356 + }, + { + "layer": 19, + "head": 8, + "ts": 0.4311 + }, + { + "layer": 24, + "head": 4, + "ts": 0.4293 + }, + { + "layer": 13, + "head": 15, + "ts": 0.4231 + }, + { + "layer": 14, + "head": 12, + "ts": 0.4213 + }, + { + "layer": 23, + "head": 5, + "ts": 0.4142 + }, + { + "layer": 9, + "head": 13, + "ts": 0.4124 + }, + { + "layer": 21, + "head": 13, + "ts": 0.4089 + }, + { + "layer": 16, + "head": 13, + "ts": 0.4018 + }, + { + "layer": 17, + "head": 13, + "ts": 0.4018 + }, + { + "layer": 19, + "head": 13, + "ts": 0.4018 + }, + { + "layer": 13, + "head": 4, + "ts": 0.3849 + }, + { + "layer": 4, + "head": 10, + "ts": 0.3831 + }, + { + "layer": 18, + "head": 10, + "ts": 0.3778 + }, + { + "layer": 26, + "head": 0, + "ts": 0.376 + }, + { + "layer": 18, + "head": 2, + "ts": 0.3724 + }, + { + "layer": 12, + "head": 1, + "ts": 0.3716 + }, + { + "layer": 2, + "head": 14, + "ts": 0.3698 + }, + { + "layer": 19, + "head": 12, + "ts": 0.3662 + }, + { + "layer": 20, + "head": 6, + "ts": 0.3636 + }, + { + "layer": 21, + "head": 15, + "ts": 0.3618 + }, + { + "layer": 8, + "head": 5, + "ts": 0.3547 + }, + { + "layer": 25, + "head": 15, + "ts": 0.3547 + }, + { + "layer": 23, + "head": 10, + "ts": 0.3538 + }, + { + "layer": 21, + "head": 11, + "ts": 0.352 + }, + { + "layer": 25, + "head": 8, + "ts": 0.352 + }, + { + "layer": 1, + "head": 8, + "ts": 0.3502 + }, + { + "layer": 2, + "head": 8, + "ts": 0.3458 + }, + { + "layer": 20, + "head": 4, + "ts": 0.3449 + }, + { + "layer": 12, + "head": 0, + "ts": 0.344 + }, + { + "layer": 17, + "head": 8, + "ts": 0.3404 + }, + { + "layer": 20, + "head": 12, + "ts": 0.3298 + }, + { + "layer": 5, + "head": 15, + "ts": 0.3227 + }, + { + "layer": 10, + "head": 5, + "ts": 0.3227 + }, + { + "layer": 26, + "head": 1, + "ts": 0.3218 + }, + { + "layer": 23, + "head": 11, + "ts": 0.3182 + }, + { + "layer": 12, + "head": 8, + "ts": 0.3173 + }, + { + "layer": 20, + "head": 8, + "ts": 0.3173 + }, + { + "layer": 21, + "head": 14, + "ts": 0.3129 + }, + { + "layer": 8, + "head": 14, + "ts": 0.3111 + }, + { + "layer": 15, + "head": 10, + "ts": 0.3111 + }, + { + "layer": 19, + "head": 2, + "ts": 0.3102 + }, + { + "layer": 16, + "head": 12, + "ts": 0.2996 + }, + { + "layer": 22, + "head": 6, + "ts": 0.2996 + }, + { + "layer": 8, + "head": 7, + "ts": 0.296 + }, + { + "layer": 7, + "head": 4, + "ts": 0.2916 + }, + { + "layer": 22, + "head": 4, + "ts": 0.288 + }, + { + "layer": 11, + "head": 10, + "ts": 0.2871 + }, + { + "layer": 18, + "head": 4, + "ts": 0.2747 + }, + { + "layer": 3, + "head": 13, + "ts": 0.2676 + }, + { + "layer": 22, + "head": 1, + "ts": 0.2667 + }, + { + "layer": 16, + "head": 9, + "ts": 0.264 + }, + { + "layer": 4, + "head": 15, + "ts": 0.2596 + }, + { + "layer": 17, + "head": 1, + "ts": 0.2596 + }, + { + "layer": 17, + "head": 0, + "ts": 0.2587 + }, + { + "layer": 2, + "head": 2, + "ts": 0.2578 + }, + { + "layer": 11, + "head": 7, + "ts": 0.2578 + }, + { + "layer": 22, + "head": 11, + "ts": 0.2569 + }, + { + "layer": 3, + "head": 4, + "ts": 0.256 + }, + { + "layer": 12, + "head": 3, + "ts": 0.2551 + }, + { + "layer": 16, + "head": 8, + "ts": 0.2542 + }, + { + "layer": 6, + "head": 0, + "ts": 0.2524 + }, + { + "layer": 10, + "head": 3, + "ts": 0.2507 + }, + { + "layer": 21, + "head": 10, + "ts": 0.2507 + }, + { + "layer": 22, + "head": 5, + "ts": 0.2507 + }, + { + "layer": 22, + "head": 10, + "ts": 0.2498 + }, + { + "layer": 6, + "head": 9, + "ts": 0.2489 + }, + { + "layer": 23, + "head": 6, + "ts": 0.2489 + }, + { + "layer": 17, + "head": 12, + "ts": 0.248 + }, + { + "layer": 17, + "head": 9, + "ts": 0.2453 + }, + { + "layer": 21, + "head": 6, + "ts": 0.2427 + }, + { + "layer": 21, + "head": 4, + "ts": 0.2391 + }, + { + "layer": 17, + "head": 10, + "ts": 0.2382 + }, + { + "layer": 12, + "head": 10, + "ts": 0.2373 + }, + { + "layer": 8, + "head": 6, + "ts": 0.2364 + }, + { + "layer": 9, + "head": 9, + "ts": 0.2364 + }, + { + "layer": 14, + "head": 11, + "ts": 0.2364 + }, + { + "layer": 20, + "head": 13, + "ts": 0.2364 + }, + { + "layer": 22, + "head": 14, + "ts": 0.2347 + }, + { + "layer": 17, + "head": 2, + "ts": 0.2329 + }, + { + "layer": 17, + "head": 7, + "ts": 0.2293 + }, + { + "layer": 9, + "head": 8, + "ts": 0.2284 + }, + { + "layer": 7, + "head": 13, + "ts": 0.2276 + }, + { + "layer": 14, + "head": 5, + "ts": 0.2222 + }, + { + "layer": 5, + "head": 5, + "ts": 0.2213 + }, + { + "layer": 11, + "head": 1, + "ts": 0.2204 + }, + { + "layer": 15, + "head": 6, + "ts": 0.2187 + }, + { + "layer": 26, + "head": 9, + "ts": 0.2151 + }, + { + "layer": 6, + "head": 1, + "ts": 0.2142 + }, + { + "layer": 12, + "head": 14, + "ts": 0.2124 + }, + { + "layer": 22, + "head": 0, + "ts": 0.2116 + }, + { + "layer": 14, + "head": 13, + "ts": 0.2107 + }, + { + "layer": 4, + "head": 11, + "ts": 0.2053 + }, + { + "layer": 26, + "head": 11, + "ts": 0.2053 + }, + { + "layer": 3, + "head": 15, + "ts": 0.2036 + }, + { + "layer": 12, + "head": 15, + "ts": 0.2036 + }, + { + "layer": 3, + "head": 14, + "ts": 0.2027 + }, + { + "layer": 23, + "head": 12, + "ts": 0.2018 + }, + { + "layer": 18, + "head": 15, + "ts": 0.2 + }, + { + "layer": 8, + "head": 13, + "ts": 0.1991 + }, + { + "layer": 18, + "head": 11, + "ts": 0.1991 + }, + { + "layer": 18, + "head": 13, + "ts": 0.1973 + }, + { + "layer": 9, + "head": 5, + "ts": 0.1947 + }, + { + "layer": 14, + "head": 0, + "ts": 0.1938 + }, + { + "layer": 2, + "head": 6, + "ts": 0.192 + }, + { + "layer": 13, + "head": 1, + "ts": 0.192 + }, + { + "layer": 9, + "head": 0, + "ts": 0.1902 + }, + { + "layer": 11, + "head": 8, + "ts": 0.1893 + }, + { + "layer": 9, + "head": 4, + "ts": 0.1884 + }, + { + "layer": 21, + "head": 7, + "ts": 0.1884 + }, + { + "layer": 14, + "head": 10, + "ts": 0.1876 + }, + { + "layer": 6, + "head": 14, + "ts": 0.1867 + }, + { + "layer": 14, + "head": 6, + "ts": 0.1867 + }, + { + "layer": 11, + "head": 0, + "ts": 0.1858 + }, + { + "layer": 17, + "head": 11, + "ts": 0.1858 + }, + { + "layer": 3, + "head": 3, + "ts": 0.1849 + }, + { + "layer": 15, + "head": 4, + "ts": 0.1849 + }, + { + "layer": 15, + "head": 14, + "ts": 0.1849 + }, + { + "layer": 10, + "head": 2, + "ts": 0.1822 + }, + { + "layer": 18, + "head": 6, + "ts": 0.1822 + }, + { + "layer": 17, + "head": 3, + "ts": 0.1813 + }, + { + "layer": 11, + "head": 6, + "ts": 0.1804 + }, + { + "layer": 13, + "head": 2, + "ts": 0.1804 + }, + { + "layer": 10, + "head": 4, + "ts": 0.1787 + }, + { + "layer": 15, + "head": 2, + "ts": 0.1751 + }, + { + "layer": 15, + "head": 5, + "ts": 0.1751 + }, + { + "layer": 8, + "head": 10, + "ts": 0.1742 + }, + { + "layer": 9, + "head": 10, + "ts": 0.1733 + }, + { + "layer": 7, + "head": 15, + "ts": 0.1724 + }, + { + "layer": 1, + "head": 13, + "ts": 0.1716 + }, + { + "layer": 6, + "head": 4, + "ts": 0.1716 + }, + { + "layer": 12, + "head": 12, + "ts": 0.1716 + }, + { + "layer": 19, + "head": 1, + "ts": 0.1707 + }, + { + "layer": 19, + "head": 6, + "ts": 0.1707 + }, + { + "layer": 12, + "head": 2, + "ts": 0.1698 + }, + { + "layer": 19, + "head": 0, + "ts": 0.1698 + }, + { + "layer": 26, + "head": 6, + "ts": 0.1689 + }, + { + "layer": 11, + "head": 5, + "ts": 0.168 + }, + { + "layer": 15, + "head": 1, + "ts": 0.168 + }, + { + "layer": 9, + "head": 7, + "ts": 0.1671 + }, + { + "layer": 15, + "head": 9, + "ts": 0.1653 + }, + { + "layer": 3, + "head": 2, + "ts": 0.1644 + }, + { + "layer": 7, + "head": 6, + "ts": 0.1644 + }, + { + "layer": 16, + "head": 6, + "ts": 0.1644 + }, + { + "layer": 14, + "head": 2, + "ts": 0.1636 + }, + { + "layer": 15, + "head": 11, + "ts": 0.1618 + }, + { + "layer": 18, + "head": 7, + "ts": 0.1618 + }, + { + "layer": 20, + "head": 5, + "ts": 0.1618 + }, + { + "layer": 9, + "head": 2, + "ts": 0.1609 + }, + { + "layer": 23, + "head": 9, + "ts": 0.1609 + }, + { + "layer": 16, + "head": 3, + "ts": 0.1591 + }, + { + "layer": 20, + "head": 2, + "ts": 0.1591 + }, + { + "layer": 1, + "head": 0, + "ts": 0.1573 + }, + { + "layer": 10, + "head": 11, + "ts": 0.1573 + }, + { + "layer": 1, + "head": 1, + "ts": 0.1556 + }, + { + "layer": 21, + "head": 5, + "ts": 0.1556 + }, + { + "layer": 27, + "head": 6, + "ts": 0.1547 + }, + { + "layer": 5, + "head": 14, + "ts": 0.152 + }, + { + "layer": 10, + "head": 9, + "ts": 0.152 + }, + { + "layer": 15, + "head": 3, + "ts": 0.152 + }, + { + "layer": 6, + "head": 15, + "ts": 0.1511 + }, + { + "layer": 7, + "head": 11, + "ts": 0.1484 + }, + { + "layer": 8, + "head": 2, + "ts": 0.1484 + }, + { + "layer": 10, + "head": 13, + "ts": 0.1484 + }, + { + "layer": 8, + "head": 11, + "ts": 0.1467 + }, + { + "layer": 14, + "head": 7, + "ts": 0.1467 + }, + { + "layer": 10, + "head": 8, + "ts": 0.1449 + }, + { + "layer": 23, + "head": 13, + "ts": 0.1449 + }, + { + "layer": 0, + "head": 11, + "ts": 0.1431 + }, + { + "layer": 1, + "head": 3, + "ts": 0.1404 + }, + { + "layer": 6, + "head": 5, + "ts": 0.1396 + }, + { + "layer": 5, + "head": 4, + "ts": 0.1387 + }, + { + "layer": 7, + "head": 5, + "ts": 0.1369 + }, + { + "layer": 8, + "head": 12, + "ts": 0.136 + }, + { + "layer": 13, + "head": 5, + "ts": 0.136 + }, + { + "layer": 7, + "head": 0, + "ts": 0.1342 + }, + { + "layer": 16, + "head": 5, + "ts": 0.1333 + }, + { + "layer": 0, + "head": 10, + "ts": 0.1316 + }, + { + "layer": 1, + "head": 14, + "ts": 0.1316 + }, + { + "layer": 7, + "head": 10, + "ts": 0.1316 + }, + { + "layer": 8, + "head": 9, + "ts": 0.1316 + }, + { + "layer": 26, + "head": 12, + "ts": 0.1307 + }, + { + "layer": 5, + "head": 9, + "ts": 0.1298 + }, + { + "layer": 9, + "head": 14, + "ts": 0.1298 + }, + { + "layer": 5, + "head": 8, + "ts": 0.1289 + }, + { + "layer": 13, + "head": 3, + "ts": 0.1289 + }, + { + "layer": 3, + "head": 1, + "ts": 0.1262 + }, + { + "layer": 5, + "head": 7, + "ts": 0.1262 + }, + { + "layer": 7, + "head": 14, + "ts": 0.1262 + }, + { + "layer": 15, + "head": 0, + "ts": 0.1262 + }, + { + "layer": 3, + "head": 12, + "ts": 0.1253 + }, + { + "layer": 6, + "head": 3, + "ts": 0.1244 + }, + { + "layer": 7, + "head": 12, + "ts": 0.1244 + }, + { + "layer": 2, + "head": 15, + "ts": 0.1236 + }, + { + "layer": 4, + "head": 8, + "ts": 0.1227 + }, + { + "layer": 5, + "head": 11, + "ts": 0.1227 + }, + { + "layer": 8, + "head": 0, + "ts": 0.1227 + }, + { + "layer": 12, + "head": 5, + "ts": 0.12 + }, + { + "layer": 4, + "head": 4, + "ts": 0.1191 + }, + { + "layer": 5, + "head": 6, + "ts": 0.1191 + }, + { + "layer": 3, + "head": 0, + "ts": 0.1182 + }, + { + "layer": 3, + "head": 9, + "ts": 0.1173 + }, + { + "layer": 25, + "head": 1, + "ts": 0.1173 + }, + { + "layer": 4, + "head": 2, + "ts": 0.1164 + }, + { + "layer": 17, + "head": 14, + "ts": 0.1164 + }, + { + "layer": 9, + "head": 11, + "ts": 0.1156 + }, + { + "layer": 16, + "head": 4, + "ts": 0.1156 + }, + { + "layer": 15, + "head": 15, + "ts": 0.1147 + }, + { + "layer": 3, + "head": 8, + "ts": 0.1138 + }, + { + "layer": 11, + "head": 9, + "ts": 0.1138 + }, + { + "layer": 16, + "head": 2, + "ts": 0.1138 + }, + { + "layer": 18, + "head": 8, + "ts": 0.1138 + }, + { + "layer": 11, + "head": 4, + "ts": 0.1129 + }, + { + "layer": 11, + "head": 15, + "ts": 0.1129 + }, + { + "layer": 17, + "head": 15, + "ts": 0.112 + }, + { + "layer": 21, + "head": 2, + "ts": 0.112 + }, + { + "layer": 2, + "head": 13, + "ts": 0.1111 + }, + { + "layer": 10, + "head": 12, + "ts": 0.1111 + }, + { + "layer": 15, + "head": 7, + "ts": 0.1093 + }, + { + "layer": 12, + "head": 13, + "ts": 0.1084 + }, + { + "layer": 20, + "head": 1, + "ts": 0.1084 + }, + { + "layer": 0, + "head": 3, + "ts": 0.1076 + }, + { + "layer": 5, + "head": 0, + "ts": 0.1076 + }, + { + "layer": 4, + "head": 3, + "ts": 0.1058 + }, + { + "layer": 22, + "head": 2, + "ts": 0.1049 + }, + { + "layer": 8, + "head": 8, + "ts": 0.1031 + }, + { + "layer": 0, + "head": 0, + "ts": 0.1022 + }, + { + "layer": 26, + "head": 13, + "ts": 0.1022 + }, + { + "layer": 27, + "head": 10, + "ts": 0.1022 + }, + { + "layer": 0, + "head": 2, + "ts": 0.1013 + }, + { + "layer": 5, + "head": 1, + "ts": 0.1013 + }, + { + "layer": 10, + "head": 15, + "ts": 0.1004 + }, + { + "layer": 20, + "head": 0, + "ts": 0.1004 + }, + { + "layer": 27, + "head": 1, + "ts": 0.1004 + } + ], + "alignment_heads_compact": [ + [ + 20, + 3 + ], + [ + 21, + 9 + ], + [ + 11, + 2 + ], + [ + 16, + 14 + ], + [ + 11, + 13 + ], + [ + 21, + 8 + ], + [ + 14, + 14 + ], + [ + 6, + 7 + ], + [ + 6, + 6 + ], + [ + 11, + 3 + ], + [ + 14, + 15 + ], + [ + 6, + 10 + ], + [ + 19, + 7 + ], + [ + 11, + 12 + ], + [ + 6, + 11 + ], + [ + 16, + 15 + ], + [ + 6, + 13 + ], + [ + 3, + 10 + ], + [ + 13, + 11 + ], + [ + 2, + 11 + ], + [ + 2, + 10 + ], + [ + 20, + 14 + ], + [ + 4, + 13 + ], + [ + 18, + 5 + ], + [ + 22, + 7 + ], + [ + 8, + 4 + ], + [ + 16, + 10 + ], + [ + 25, + 5 + ], + [ + 23, + 7 + ], + [ + 2, + 3 + ], + [ + 6, + 12 + ], + [ + 18, + 3 + ], + [ + 19, + 5 + ], + [ + 22, + 9 + ], + [ + 23, + 14 + ], + [ + 24, + 5 + ], + [ + 24, + 14 + ], + [ + 23, + 4 + ], + [ + 24, + 15 + ], + [ + 22, + 12 + ], + [ + 11, + 11 + ], + [ + 3, + 11 + ], + [ + 26, + 8 + ], + [ + 13, + 0 + ], + [ + 23, + 15 + ], + [ + 25, + 11 + ], + [ + 8, + 1 + ], + [ + 24, + 7 + ], + [ + 25, + 4 + ], + [ + 13, + 14 + ], + [ + 25, + 7 + ], + [ + 22, + 8 + ], + [ + 24, + 13 + ], + [ + 25, + 6 + ], + [ + 25, + 12 + ], + [ + 24, + 12 + ], + [ + 20, + 15 + ], + [ + 24, + 11 + ], + [ + 25, + 13 + ], + [ + 26, + 2 + ], + [ + 20, + 11 + ], + [ + 25, + 10 + ], + [ + 25, + 14 + ], + [ + 26, + 7 + ], + [ + 12, + 6 + ], + [ + 26, + 14 + ], + [ + 16, + 11 + ], + [ + 24, + 10 + ], + [ + 26, + 4 + ], + [ + 26, + 3 + ], + [ + 21, + 12 + ], + [ + 19, + 3 + ], + [ + 22, + 13 + ], + [ + 24, + 6 + ], + [ + 1, + 15 + ], + [ + 26, + 5 + ], + [ + 21, + 0 + ], + [ + 23, + 1 + ], + [ + 20, + 7 + ], + [ + 21, + 1 + ], + [ + 13, + 7 + ], + [ + 19, + 9 + ], + [ + 20, + 10 + ], + [ + 26, + 15 + ], + [ + 9, + 12 + ], + [ + 23, + 0 + ], + [ + 25, + 9 + ], + [ + 13, + 6 + ], + [ + 16, + 7 + ], + [ + 5, + 10 + ], + [ + 19, + 4 + ], + [ + 4, + 12 + ], + [ + 19, + 8 + ], + [ + 24, + 4 + ], + [ + 13, + 15 + ], + [ + 14, + 12 + ], + [ + 23, + 5 + ], + [ + 9, + 13 + ], + [ + 21, + 13 + ], + [ + 16, + 13 + ], + [ + 17, + 13 + ], + [ + 19, + 13 + ], + [ + 13, + 4 + ], + [ + 4, + 10 + ], + [ + 18, + 10 + ], + [ + 26, + 0 + ], + [ + 18, + 2 + ], + [ + 12, + 1 + ], + [ + 2, + 14 + ], + [ + 19, + 12 + ], + [ + 20, + 6 + ], + [ + 21, + 15 + ], + [ + 8, + 5 + ], + [ + 25, + 15 + ], + [ + 23, + 10 + ], + [ + 21, + 11 + ], + [ + 25, + 8 + ], + [ + 1, + 8 + ], + [ + 2, + 8 + ], + [ + 20, + 4 + ], + [ + 12, + 0 + ], + [ + 17, + 8 + ], + [ + 20, + 12 + ], + [ + 5, + 15 + ], + [ + 10, + 5 + ], + [ + 26, + 1 + ], + [ + 23, + 11 + ], + [ + 12, + 8 + ], + [ + 20, + 8 + ], + [ + 21, + 14 + ], + [ + 8, + 14 + ], + [ + 15, + 10 + ], + [ + 19, + 2 + ], + [ + 16, + 12 + ], + [ + 22, + 6 + ], + [ + 8, + 7 + ], + [ + 7, + 4 + ], + [ + 22, + 4 + ], + [ + 11, + 10 + ], + [ + 18, + 4 + ], + [ + 3, + 13 + ], + [ + 22, + 1 + ], + [ + 16, + 9 + ], + [ + 4, + 15 + ], + [ + 17, + 1 + ], + [ + 17, + 0 + ], + [ + 2, + 2 + ], + [ + 11, + 7 + ], + [ + 22, + 11 + ], + [ + 3, + 4 + ], + [ + 12, + 3 + ], + [ + 16, + 8 + ], + [ + 6, + 0 + ], + [ + 10, + 3 + ], + [ + 21, + 10 + ], + [ + 22, + 5 + ], + [ + 22, + 10 + ], + [ + 6, + 9 + ], + [ + 23, + 6 + ], + [ + 17, + 12 + ], + [ + 17, + 9 + ], + [ + 21, + 6 + ], + [ + 21, + 4 + ], + [ + 17, + 10 + ], + [ + 12, + 10 + ], + [ + 8, + 6 + ], + [ + 9, + 9 + ], + [ + 14, + 11 + ], + [ + 20, + 13 + ], + [ + 22, + 14 + ], + [ + 17, + 2 + ], + [ + 17, + 7 + ], + [ + 9, + 8 + ], + [ + 7, + 13 + ], + [ + 14, + 5 + ], + [ + 5, + 5 + ], + [ + 11, + 1 + ], + [ + 15, + 6 + ], + [ + 26, + 9 + ], + [ + 6, + 1 + ], + [ + 12, + 14 + ], + [ + 22, + 0 + ], + [ + 14, + 13 + ], + [ + 4, + 11 + ], + [ + 26, + 11 + ], + [ + 3, + 15 + ], + [ + 12, + 15 + ], + [ + 3, + 14 + ], + [ + 23, + 12 + ], + [ + 18, + 15 + ], + [ + 8, + 13 + ], + [ + 18, + 11 + ], + [ + 18, + 13 + ], + [ + 9, + 5 + ], + [ + 14, + 0 + ], + [ + 2, + 6 + ], + [ + 13, + 1 + ], + [ + 9, + 0 + ], + [ + 11, + 8 + ], + [ + 9, + 4 + ], + [ + 21, + 7 + ], + [ + 14, + 10 + ], + [ + 6, + 14 + ], + [ + 14, + 6 + ], + [ + 11, + 0 + ], + [ + 17, + 11 + ], + [ + 3, + 3 + ], + [ + 15, + 4 + ], + [ + 15, + 14 + ], + [ + 10, + 2 + ], + [ + 18, + 6 + ], + [ + 17, + 3 + ], + [ + 11, + 6 + ], + [ + 13, + 2 + ], + [ + 10, + 4 + ], + [ + 15, + 2 + ], + [ + 15, + 5 + ], + [ + 8, + 10 + ], + [ + 9, + 10 + ], + [ + 7, + 15 + ], + [ + 1, + 13 + ], + [ + 6, + 4 + ], + [ + 12, + 12 + ], + [ + 19, + 1 + ], + [ + 19, + 6 + ], + [ + 12, + 2 + ], + [ + 19, + 0 + ], + [ + 26, + 6 + ], + [ + 11, + 5 + ], + [ + 15, + 1 + ], + [ + 9, + 7 + ], + [ + 15, + 9 + ], + [ + 3, + 2 + ], + [ + 7, + 6 + ], + [ + 16, + 6 + ], + [ + 14, + 2 + ], + [ + 15, + 11 + ], + [ + 18, + 7 + ], + [ + 20, + 5 + ], + [ + 9, + 2 + ], + [ + 23, + 9 + ], + [ + 16, + 3 + ], + [ + 20, + 2 + ], + [ + 1, + 0 + ], + [ + 10, + 11 + ], + [ + 1, + 1 + ], + [ + 21, + 5 + ], + [ + 27, + 6 + ], + [ + 5, + 14 + ], + [ + 10, + 9 + ], + [ + 15, + 3 + ], + [ + 6, + 15 + ], + [ + 7, + 11 + ], + [ + 8, + 2 + ], + [ + 10, + 13 + ], + [ + 8, + 11 + ], + [ + 14, + 7 + ], + [ + 10, + 8 + ], + [ + 23, + 13 + ], + [ + 0, + 11 + ], + [ + 1, + 3 + ], + [ + 6, + 5 + ], + [ + 5, + 4 + ], + [ + 7, + 5 + ], + [ + 8, + 12 + ], + [ + 13, + 5 + ], + [ + 7, + 0 + ], + [ + 16, + 5 + ], + [ + 0, + 10 + ], + [ + 1, + 14 + ], + [ + 7, + 10 + ], + [ + 8, + 9 + ], + [ + 26, + 12 + ], + [ + 5, + 9 + ], + [ + 9, + 14 + ], + [ + 5, + 8 + ], + [ + 13, + 3 + ], + [ + 3, + 1 + ], + [ + 5, + 7 + ], + [ + 7, + 14 + ], + [ + 15, + 0 + ], + [ + 3, + 12 + ], + [ + 6, + 3 + ], + [ + 7, + 12 + ], + [ + 2, + 15 + ], + [ + 4, + 8 + ], + [ + 5, + 11 + ], + [ + 8, + 0 + ], + [ + 12, + 5 + ], + [ + 4, + 4 + ], + [ + 5, + 6 + ], + [ + 3, + 0 + ], + [ + 3, + 9 + ], + [ + 25, + 1 + ], + [ + 4, + 2 + ], + [ + 17, + 14 + ], + [ + 9, + 11 + ], + [ + 16, + 4 + ], + [ + 15, + 15 + ], + [ + 3, + 8 + ], + [ + 11, + 9 + ], + [ + 16, + 2 + ], + [ + 18, + 8 + ], + [ + 11, + 4 + ], + [ + 11, + 15 + ], + [ + 17, + 15 + ], + [ + 21, + 2 + ], + [ + 2, + 13 + ], + [ + 10, + 12 + ], + [ + 15, + 7 + ], + [ + 12, + 13 + ], + [ + 20, + 1 + ], + [ + 0, + 3 + ], + [ + 5, + 0 + ], + [ + 4, + 3 + ], + [ + 22, + 2 + ], + [ + 8, + 8 + ], + [ + 0, + 0 + ], + [ + 26, + 13 + ], + [ + 27, + 10 + ], + [ + 0, + 2 + ], + [ + 5, + 1 + ], + [ + 10, + 15 + ], + [ + 20, + 0 + ], + [ + 27, + 1 + ] + ] +} \ No newline at end of file diff --git a/scripts/alignment_heads_qwen3_asr_1.7B.png b/scripts/alignment_heads_qwen3_asr_1.7B.png new file mode 100644 index 0000000..d564132 Binary files /dev/null and b/scripts/alignment_heads_qwen3_asr_1.7B.png differ diff --git a/scripts/detect_alignment_heads_qwen3.py b/scripts/detect_alignment_heads_qwen3.py new file mode 100644 index 0000000..6cb268e --- /dev/null +++ b/scripts/detect_alignment_heads_qwen3.py @@ -0,0 +1,703 @@ +#!/usr/bin/env python3 +""" +Detect alignment heads in Qwen3-ASR for SimulStreaming-style inference. + +Qwen3-ASR is a decoder-only multimodal model: audio is encoded by an audio +encoder and the resulting embeddings are injected into the text sequence +(replacing <|audio_pad|> placeholder tokens). The text decoder then attends +over the full sequence -- both audio-derived tokens and text tokens -- via +causal self-attention. There is **no** cross-attention. + +For AlignAtt-style streaming, we need to find which (layer, head) pairs in +the text decoder's self-attention best track the monotonic alignment between +generated text tokens and their corresponding audio positions. + +Algorithm +--------- +For each audio sample with a known transcript: + 1. Run Qwen3-ASR with output_attentions=True + 2. Use the ForcedAligner to get ground-truth word->timestamp alignments + 3. Convert timestamps to audio token positions in the input sequence + 4. For each generated text token, check whether the argmax of each + attention head (over the audio-token region) points to the correct + audio position (as determined by the forced aligner) + 5. Accumulate scores per (layer, head) + +The heads whose attention argmax matches the ground-truth alignment most +often are the "alignment heads" usable for SimulStreaming. + +Reference: Adapted from scripts/determine_alignment_heads.py (Whisper) and + iwslt26-sst/SimulMT_tests/heads/detect_translation_heads_qwen3.py +""" + +import argparse +import io +import json +import logging +import re +import time +from difflib import SequenceMatcher +from typing import List, Optional, Tuple + +import numpy as np +import soundfile as sf +import torch + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger(__name__) + +# ── Compatibility patches for qwen_asr 0.0.6 + transformers >= 5.3 ──── +def _apply_transformers_compat_patches(): + """Apply all necessary patches to make qwen_asr work with transformers >= 5.3.""" + # 1. check_model_inputs was removed + try: + import transformers.utils.generic as _g + if not hasattr(_g, "check_model_inputs"): + def check_model_inputs(*args, **kwargs): + def decorator(fn): + return fn + return decorator + _g.check_model_inputs = check_model_inputs + except ImportError: + pass + + # 2. 'default' rope type was removed from ROPE_INIT_FUNCTIONS + try: + from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS + if "default" not in ROPE_INIT_FUNCTIONS: + def _compute_default_rope_parameters(config=None, device=None, seq_len=None, **kwargs): + if hasattr(config, "head_dim"): + head_dim = config.head_dim + else: + head_dim = config.hidden_size // config.num_attention_heads + partial = getattr(config, "partial_rotary_factor", 1.0) + dim = int(head_dim * partial) + base = config.rope_theta + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) + return inv_freq, 1.0 + ROPE_INIT_FUNCTIONS["default"] = _compute_default_rope_parameters + except ImportError: + pass + + # 3. pad_token_id missing on thinker config + try: + from qwen_asr.core.transformers_backend.configuration_qwen3_asr import ( + Qwen3ASRThinkerConfig, + ) + if not hasattr(Qwen3ASRThinkerConfig, "pad_token_id"): + Qwen3ASRThinkerConfig.pad_token_id = None + except ImportError: + pass + + # 4. fix_mistral_regex is now handled internally by transformers 5.3; + # qwen_asr passes it explicitly, causing a duplicate-kwarg error. + try: + from transformers.models.auto import processing_auto + _orig_ap_from_pretrained = processing_auto.AutoProcessor.from_pretrained.__func__ + + @classmethod + def _patched_ap_from_pretrained(cls, *args, **kwargs): + kwargs.pop("fix_mistral_regex", None) + return _orig_ap_from_pretrained(cls, *args, **kwargs) + + processing_auto.AutoProcessor.from_pretrained = _patched_ap_from_pretrained + except Exception: + pass + + # 5. _finalize_model_loading calls initialize_weights which expects + # compute_default_rope_parameters on RotaryEmbedding modules. + try: + from qwen_asr.core.transformers_backend.modeling_qwen3_asr import ( + Qwen3ASRThinkerTextRotaryEmbedding, + ) + if not hasattr(Qwen3ASRThinkerTextRotaryEmbedding, "compute_default_rope_parameters"): + @staticmethod + def _compute_default_rope_parameters(config=None, device=None, seq_len=None, **kwargs): + if hasattr(config, "head_dim"): + head_dim = config.head_dim + else: + head_dim = config.hidden_size // config.num_attention_heads + partial = getattr(config, "partial_rotary_factor", 1.0) + dim = int(head_dim * partial) + base = config.rope_theta + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) + return inv_freq, 1.0 + Qwen3ASRThinkerTextRotaryEmbedding.compute_default_rope_parameters = _compute_default_rope_parameters + except ImportError: + pass + +_apply_transformers_compat_patches() + +# ── Constants ──────────────────────────────────────────────────────── +SAMPLE_RATE = 16000 +TS_THRESHOLD = 0.1 # Minimum Translation Score to qualify as alignment head +MIN_TEXT_SIMILARITY = 0.3 # Skip clips where generated text is too different from ground truth + + +def text_similarity(generated: str, reference: str) -> float: + """Compute text similarity between generated and reference transcriptions. + + Normalizes both strings (lowercase, remove punctuation, collapse whitespace) + then returns SequenceMatcher ratio. + """ + def normalize(s): + s = s.lower() + s = re.sub(r'[^\w\s]', '', s) + return re.sub(r'\s+', ' ', s).strip() + + gen_norm = normalize(generated) + ref_norm = normalize(reference) + if not gen_norm or not ref_norm: + return 0.0 + return SequenceMatcher(None, gen_norm, ref_norm).ratio() + + +def load_dataset_clips(name, config, split, limit): + """Load audio clips from a HuggingFace dataset.""" + from datasets import Audio as DatasetAudio + from datasets import load_dataset + + ds = load_dataset(name, config, split=split) + ds = ds.cast_column("audio", DatasetAudio(decode=False)) + clips = [] + for idx, row in enumerate(ds): + if limit is not None and idx >= limit: + break + audio_field = row["audio"] + transcript = row["text"] + + waveform_np, _ = sf.read(io.BytesIO(audio_field["bytes"]), dtype="float32") + if waveform_np.ndim > 1: + waveform_np = waveform_np.mean(axis=1) + + clips.append((waveform_np, str(transcript))) + return clips + + +def get_device(): + """Select the best available device.""" + if torch.backends.mps.is_available(): + logger.info("Using MPS (Apple Silicon GPU)") + return torch.device("mps") + elif torch.cuda.is_available(): + logger.info("Using CUDA (%s)", torch.cuda.get_device_name()) + return torch.device("cuda") + else: + logger.info("Using CPU (will be slow)") + return torch.device("cpu") + + +def load_qwen3_asr(model_id: str, device: torch.device, dtype: torch.dtype): + """Load Qwen3-ASR model, processor, and forced aligner.""" + from qwen_asr.core.transformers_backend import ( + Qwen3ASRConfig, + Qwen3ASRForConditionalGeneration, + Qwen3ASRProcessor, + ) + from qwen_asr.inference.qwen3_forced_aligner import Qwen3ForcedAligner + from transformers import AutoConfig, AutoModel, AutoProcessor + + AutoConfig.register("qwen3_asr", Qwen3ASRConfig) + AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration) + AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor) + + logger.info("Loading model: %s (dtype=%s, device=%s)", model_id, dtype, device) + model = AutoModel.from_pretrained( + model_id, + torch_dtype=dtype, + attn_implementation="eager", + device_map=str(device), + ) + model.eval() + + # Force eager attention on all sub-modules (attn_implementation="eager" doesn't + # propagate through nested model configs in qwen_asr's custom architecture) + for name, module in model.named_modules(): + if hasattr(module, "config") and hasattr(module.config, "_attn_implementation"): + module.config._attn_implementation = "eager" + module.config._attn_implementation_internal = "eager" + + try: + processor = AutoProcessor.from_pretrained(model_id, fix_mistral_regex=True) + except TypeError: + processor = AutoProcessor.from_pretrained(model_id) + + logger.info("Loading forced aligner: Qwen/Qwen3-ForcedAligner-0.6B") + forced_aligner = Qwen3ForcedAligner.from_pretrained( + "Qwen/Qwen3-ForcedAligner-0.6B", + dtype=dtype, + device_map=str(device), + ) + + return model, processor, forced_aligner + + +def find_audio_token_range(input_ids: torch.Tensor, audio_token_id: int) -> Tuple[int, int]: + """Find the start and end positions of audio tokens in the input sequence.""" + mask = (input_ids == audio_token_id) + positions = mask.nonzero(as_tuple=True)[0] + if len(positions) == 0: + return 0, 0 + return positions[0].item(), positions[-1].item() + 1 + + +def timestamp_to_audio_token_position( + timestamp_sec: float, + audio_duration_sec: float, + audio_token_start: int, + audio_token_end: int, +) -> int: + """Convert a timestamp in seconds to the corresponding audio token position. + + Audio tokens span [audio_token_start, audio_token_end) in the input sequence. + We linearly interpolate within that range based on the timestamp fraction. + """ + n_audio_tokens = audio_token_end - audio_token_start + if n_audio_tokens <= 0 or audio_duration_sec <= 0: + return audio_token_start + + fraction = min(timestamp_sec / audio_duration_sec, 1.0) + pos = audio_token_start + int(fraction * (n_audio_tokens - 1)) + return max(audio_token_start, min(pos, audio_token_end - 1)) + + +def run_detection( + model, + processor, + forced_aligner, + clips: List[Tuple[np.ndarray, str]], + language: Optional[str], + device: torch.device, +) -> Tuple[np.ndarray, int]: + """Run alignment head detection on a set of audio clips. + + Uses PyTorch forward hooks on each self_attn module to capture attention + weights that the decoder layer discards (``hidden_states, _ = self.self_attn(...)``). + With eager attention, ``self_attn`` always returns ``(attn_output, attn_weights)`` + so the hook can read the weights from the return value. + + Returns: + g: array of shape (total_heads,) with alignment hit counts + m: total number of alignment checks performed + """ + thinker = model.thinker + text_config = thinker.config.text_config + num_layers = text_config.num_hidden_layers + num_heads = text_config.num_attention_heads + total_heads = num_layers * num_heads + + audio_token_id = thinker.config.audio_token_id + + logger.info( + "Text decoder: %d layers x %d heads = %d total heads", + num_layers, num_heads, total_heads, + ) + logger.info( + "KV heads: %d (GQA ratio: %d)", + text_config.num_key_value_heads, + num_heads // text_config.num_key_value_heads, + ) + + # Build prompt helper (same as Qwen3ASRModel._build_text_prompt) + from qwen_asr.inference.utils import normalize_language_name + + def build_messages(audio_payload): + return [ + {"role": "system", "content": ""}, + {"role": "user", "content": [{"type": "audio", "audio": audio_payload}]}, + ] + + def build_text_prompt(force_language=None): + msgs = build_messages("") + base = processor.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False) + if force_language: + base = base + f"language {force_language}" + return base + + force_lang = None + if language: + force_lang = normalize_language_name(language) + + # Stop token IDs + eos_ids = {151645, 151643} # <|im_end|>, <|endoftext|> + if processor.tokenizer.eos_token_id is not None: + eos_ids.add(processor.tokenizer.eos_token_id) + + # Decoder layers: model.thinker.model.layers[i].self_attn + decoder_layers = thinker.model.layers + + g = np.zeros(total_heads, dtype=np.int64) + m = 0 + t0 = time.time() + + for clip_idx, (waveform, transcript) in enumerate(clips): + if not transcript.strip(): + continue + + audio_duration = len(waveform) / SAMPLE_RATE + + # 1. Get forced alignment timestamps + try: + align_results = forced_aligner.align( + audio=[(waveform, SAMPLE_RATE)], + text=[transcript], + language=[force_lang or "English"], + ) + align_result = align_results[0] + except Exception as e: + logger.warning("Forced alignment failed for clip %d: %s", clip_idx, e) + continue + + if not align_result.items: + continue + + # Build word -> (start_time, end_time) mapping + word_timestamps = [] + for item in align_result.items: + word_timestamps.append((item.text, item.start_time, item.end_time)) + + # 2. Prepare inputs + text_prompt = build_text_prompt(force_language=force_lang) + inputs = processor( + text=[text_prompt], + audio=[waveform], + return_tensors="pt", + padding=True, + ) + inputs = inputs.to(model.device).to(model.dtype) + prompt_len = inputs.input_ids.shape[1] + + # Find audio token range + audio_start, audio_end = find_audio_token_range( + inputs.input_ids[0], audio_token_id, + ) + n_audio_tokens = audio_end - audio_start + + if n_audio_tokens == 0: + logger.warning("No audio tokens found in clip %d", clip_idx) + continue + + # 3. Register forward hooks on self_attn to capture attention weights. + # The decoder layer discards them: hidden_states, _ = self.self_attn(...) + # but eager_attention_forward always computes and returns attn_weights. + # We capture just the argmax over the audio region (memory-efficient). + # captured_argmax[layer_idx] = list of (num_heads,) tensors, one per decode step. + captured_argmax = {i: [] for i in range(num_layers)} + + def _make_hook(store, a_start, a_end): + def hook_fn(module, args, output): + # output = (attn_output, attn_weights) + attn_weights = output[1] + if attn_weights is None: + return + # attn_weights shape: (batch, num_heads, q_len, kv_len) + # Only capture decode steps (q_len == 1), skip prefill + if attn_weights.shape[2] != 1: + return + kv_len = attn_weights.shape[-1] + if a_end > kv_len: + return + # Attention from the new token over audio region + audio_attn = attn_weights[0, :, 0, a_start:a_end] # (num_heads, n_audio) + store.append(audio_attn.argmax(dim=-1).cpu()) # (num_heads,) + return hook_fn + + hooks = [] + for layer_idx in range(num_layers): + h = decoder_layers[layer_idx].self_attn.register_forward_hook( + _make_hook(captured_argmax[layer_idx], audio_start, audio_end) + ) + hooks.append(h) + + # 4. Run generation + try: + with torch.inference_mode(): + outputs = thinker.generate( + **inputs, + max_new_tokens=256, + do_sample=False, + ) + except Exception as e: + for h in hooks: + h.remove() + logger.warning("Generation failed for clip %d: %s", clip_idx, e) + continue + finally: + for h in hooks: + h.remove() + + # outputs is (batch, seq_len) tensor + all_generated = outputs[0, prompt_len:] + num_gen = len(all_generated) + for i, tid in enumerate(all_generated): + if tid.item() in eos_ids: + num_gen = i + break + generated_ids = all_generated[:num_gen] + + if num_gen == 0: + del outputs, captured_argmax + continue + + generated_text = processor.tokenizer.decode(generated_ids, skip_special_tokens=True) + + # Filter out hallucinated clips (e.g. "!!!" patterns) + sim = text_similarity(generated_text, transcript) + if sim < MIN_TEXT_SIMILARITY: + logger.info( + "[%d/%d] SKIP (sim=%.2f) | %s...", + clip_idx + 1, len(clips), sim, generated_text[:60], + ) + del outputs, captured_argmax + continue + + # Verify hooks captured data + n_captured = len(captured_argmax[0]) + if n_captured == 0: + logger.warning( + "No attention weights captured for clip %d (hooks may not have fired)", clip_idx + ) + del outputs, captured_argmax + continue + + # 5. Map generated tokens to word timestamps + gen_token_strings = [ + processor.tokenizer.decode([tid.item()]) for tid in generated_ids + ] + + # Map each generated token index -> forced-aligner word index + accumulated_text = "" + word_idx = 0 + token_to_word = {} + for tok_idx, tok_str in enumerate(gen_token_strings): + accumulated_text += tok_str + # Advance word index when accumulated text covers the current word + while ( + word_idx < len(word_timestamps) + and len(accumulated_text.strip()) >= sum( + len(w[0]) + 1 for w in word_timestamps[:word_idx + 1] + ) + ): + word_idx += 1 + actual_word_idx = min(word_idx, len(word_timestamps) - 1) + token_to_word[tok_idx] = actual_word_idx + + # 6. Score each head using captured argmax data + for gen_step in range(num_gen): + word_idx = token_to_word.get(gen_step, None) + if word_idx is None or word_idx >= len(word_timestamps): + continue + + _, word_start, word_end = word_timestamps[word_idx] + word_mid = (word_start + word_end) / 2.0 + + # Expected audio token position for this word + expected_pos = timestamp_to_audio_token_position( + word_mid, audio_duration, audio_start, audio_end, + ) + + # Tolerance: +/- a few audio tokens (proportional to word duration) + word_dur_tokens = max(1, int( + (word_end - word_start) / audio_duration * n_audio_tokens / 2 + )) + tolerance = max(3, word_dur_tokens) + + m += 1 + + for layer_idx in range(num_layers): + if gen_step >= len(captured_argmax[layer_idx]): + continue + argmaxes = captured_argmax[layer_idx][gen_step].numpy() # (num_heads,) + + for head_idx in range(num_heads): + attended_pos = argmaxes[head_idx] # relative to audio_start + attended_abs = audio_start + attended_pos + if abs(attended_abs - expected_pos) <= tolerance: + g[layer_idx * num_heads + head_idx] += 1 + + del outputs, captured_argmax + if device.type == "mps": + torch.mps.empty_cache() + elif device.type == "cuda": + torch.cuda.empty_cache() + + elapsed = time.time() - t0 + avg = elapsed / (clip_idx + 1) + eta = avg * (len(clips) - clip_idx - 1) + logger.info( + "[%d/%d] m=%d | %s... | %.1fs/clip | ETA: %.0fs", + clip_idx + 1, len(clips), m, + generated_text[:60], avg, eta, + ) + + return g, m + + +def main(): + parser = argparse.ArgumentParser( + description="Detect alignment heads in Qwen3-ASR for SimulStreaming" + ) + parser.add_argument( + "--model", type=str, default="Qwen/Qwen3-ASR-1.7B", + help="Qwen3-ASR model name or path", + ) + parser.add_argument( + "--dataset", type=str, default="librispeech_asr", + help="HuggingFace dataset name", + ) + parser.add_argument( + "--dataset-config", type=str, default="clean", + help="Dataset config/subset", + ) + parser.add_argument( + "--dataset-split", type=str, default="validation", + help="Dataset split", + ) + parser.add_argument( + "-n", "--num-samples", type=int, default=50, + help="Number of audio samples to process", + ) + parser.add_argument( + "--language", type=str, default="English", + help="Language for forced alignment", + ) + parser.add_argument( + "--dtype", type=str, default="bf16", + choices=["float32", "bf16", "float16"], + help="Model dtype", + ) + parser.add_argument( + "-o", "--output", type=str, default="alignment_heads_qwen3_asr.json", + help="Output JSON file", + ) + parser.add_argument( + "--heatmap", type=str, default="alignment_heads_qwen3_asr.png", + help="Output heatmap image", + ) + parser.add_argument( + "--threshold", type=float, default=TS_THRESHOLD, + help="Minimum alignment score threshold", + ) + args = parser.parse_args() + + device = get_device() + + dtype_map = { + "float32": torch.float32, + "bf16": torch.bfloat16, + "float16": torch.float16, + } + dtype = dtype_map[args.dtype] + + # Load model + model, processor, forced_aligner = load_qwen3_asr(args.model, device, dtype) + + # Load data + logger.info("Loading dataset: %s/%s [%s]", args.dataset, args.dataset_config, args.dataset_split) + clips = load_dataset_clips( + args.dataset, args.dataset_config, args.dataset_split, args.num_samples, + ) + logger.info("Loaded %d clips", len(clips)) + + # Run detection + g, m = run_detection(model, processor, forced_aligner, clips, args.language, device) + + # Compute alignment scores + thinker = model.thinker + text_config = thinker.config.text_config + num_layers = text_config.num_hidden_layers + num_heads = text_config.num_attention_heads + + ts = g / max(m, 1) + ts_matrix = ts.reshape(num_layers, num_heads) + + # Identify alignment heads + tah = [] + for l in range(num_layers): + for h in range(num_heads): + score = ts_matrix[l, h] + if score > args.threshold: + tah.append({"layer": l, "head": h, "ts": round(float(score), 4)}) + + tah.sort(key=lambda x: x["ts"], reverse=True) + + # Print results + print(f"\n{'=' * 60}") + print(f"ALIGNMENT HEADS (TS > {args.threshold}): {len(tah)} / {num_layers * num_heads}") + print(f"{'=' * 60}") + for entry in tah: + bar = "#" * int(entry["ts"] * 50) + print(f" L{entry['layer']:2d} H{entry['head']:2d} : TS={entry['ts']:.4f} {bar}") + + n_active = sum(1 for s in ts if s > args.threshold) + n_low = sum(1 for s in ts if 0 < s <= args.threshold) + n_zero = sum(1 for s in ts if s == 0) + total_heads = num_layers * num_heads + print(f"\nDistribution:") + print(f" TS > {args.threshold} (alignment heads): {n_active} ({100 * n_active / total_heads:.1f}%)") + print(f" 0 < TS <= {args.threshold} (low activity): {n_low} ({100 * n_low / total_heads:.1f}%)") + print(f" TS = 0 (inactive): {n_zero} ({100 * n_zero / total_heads:.1f}%)") + print(f"\nTotal alignable tokens checked: m={m}") + + # Save JSON + output = { + "model": args.model, + "language": args.language, + "num_layers": num_layers, + "num_heads": num_heads, + "num_kv_heads": text_config.num_key_value_heads, + "num_samples": len(clips), + "total_alignable_tokens": int(m), + "ts_threshold": args.threshold, + "ts_matrix": ts_matrix.tolist(), + "alignment_heads": tah, + # WhisperLiveKit-compatible format: list of [layer, head] pairs + "alignment_heads_compact": [[e["layer"], e["head"]] for e in tah], + } + with open(args.output, "w") as f: + json.dump(output, f, indent=2) + logger.info("Results saved to %s", args.output) + + # Generate heatmap + try: + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + fig, ax = plt.subplots( + figsize=(max(10, num_heads * 0.6), max(8, num_layers * 0.35)), + ) + im = ax.imshow( + ts_matrix, + aspect="auto", + cmap="RdYlBu_r", + vmin=0, + vmax=max(0.4, ts_matrix.max()), + interpolation="nearest", + ) + ax.set_xlabel("Head ID", fontsize=12) + ax.set_ylabel("Layer", fontsize=12) + ax.set_title( + f"Alignment Scores - {args.model}\n" + f"{len(tah)} alignment heads (TS > {args.threshold}), n={len(clips)}", + fontsize=13, + ) + ax.set_xticks(range(num_heads)) + ax.set_yticks(range(num_layers)) + plt.colorbar(im, ax=ax, label="Alignment Score", shrink=0.8) + + for entry in tah: + ax.add_patch(plt.Rectangle( + (entry["head"] - 0.5, entry["layer"] - 0.5), + 1, 1, fill=False, edgecolor="red", linewidth=1.5, + )) + + plt.tight_layout() + plt.savefig(args.heatmap, dpi=150) + logger.info("Heatmap saved to %s", args.heatmap) + except Exception as e: + logger.warning("Could not generate heatmap: %s", e) + + +if __name__ == "__main__": + main() diff --git a/scripts/generate_architecture.py b/scripts/generate_architecture.py new file mode 100644 index 0000000..7f42d45 --- /dev/null +++ b/scripts/generate_architecture.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 +"""Generate the architecture.png diagram for WhisperLiveKit README.""" + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +from matplotlib.patches import FancyBboxPatch, FancyArrowPatch + +# ── Colours ── +C_BG = "#1a1a2e" +C_PANEL = "#16213e" +C_PANEL2 = "#0f3460" +C_ACCENT = "#e94560" +C_GREEN = "#4ecca3" +C_ORANGE = "#f5a623" +C_BLUE = "#4a9eff" +C_PURPLE = "#b06af2" +C_PINK = "#ff6b9d" +C_YELLOW = "#f0e68c" +C_TEXT = "#e8e8e8" +C_TEXTDIM = "#a0a0b0" +C_BOX_BG = "#1e2d4a" +C_BOX_BG2 = "#2a1a3a" +C_BOX_BG3 = "#1a3a2a" +C_BORDER = "#3a4a6a" + +fig, ax = plt.subplots(1, 1, figsize=(20, 12), facecolor=C_BG) +ax.set_xlim(0, 20) +ax.set_ylim(0, 12) +ax.set_aspect("equal") +ax.axis("off") +fig.subplots_adjust(left=0.01, right=0.99, top=0.97, bottom=0.01) + + +def box(x, y, w, h, label, color=C_BORDER, bg=C_BOX_BG, fontsize=8, bold=False, + text_color=C_TEXT, radius=0.15): + rect = FancyBboxPatch( + (x, y), w, h, + boxstyle=f"round,pad=0.05,rounding_size={radius}", + facecolor=bg, edgecolor=color, linewidth=1.2, + ) + ax.add_patch(rect) + weight = "bold" if bold else "normal" + ax.text(x + w/2, y + h/2, label, ha="center", va="center", + fontsize=fontsize, color=text_color, fontweight=weight, family="monospace") + return rect + + +def arrow(x1, y1, x2, y2, color=C_TEXTDIM, style="->", lw=1.2): + ax.annotate("", xy=(x2, y2), xytext=(x1, y1), + arrowprops=dict(arrowstyle=style, color=color, lw=lw)) + + +def section_box(x, y, w, h, title, bg=C_PANEL, border=C_BORDER, title_color=C_ACCENT): + rect = FancyBboxPatch( + (x, y), w, h, + boxstyle="round,pad=0.05,rounding_size=0.2", + facecolor=bg, edgecolor=border, linewidth=1.5, + ) + ax.add_patch(rect) + ax.text(x + 0.15, y + h - 0.25, title, ha="left", va="top", + fontsize=9, color=title_color, fontweight="bold", family="monospace") + + +# ═══════════════════════════════════════════════════════════════════ +# Title +# ═══════════════════════════════════════════════════════════════════ +ax.text(10, 11.7, "WhisperLiveKit Architecture", ha="center", va="center", + fontsize=16, color=C_TEXT, fontweight="bold", family="monospace") +ax.text(10, 11.35, "CLI commands: serve | listen | run | transcribe | bench | diagnose | models | pull | rm | check", + ha="center", va="center", fontsize=7, color=C_TEXTDIM, family="monospace") + +# ═══════════════════════════════════════════════════════════════════ +# Left: Client / Server +# ═══════════════════════════════════════════════════════════════════ +section_box(0.1, 7.0, 3.5, 4.0, "FastAPI Server", border=C_GREEN) + +box(0.3, 10.0, 1.5, 0.5, "Web UI\nHTML + JS", color=C_GREEN, fontsize=7) +box(2.0, 10.0, 1.4, 0.5, "Frontend\n(optional)", color=C_GREEN, fontsize=7) + +box(0.3, 9.1, 3.1, 0.6, "WebSocket /asr • /v1/listen", color=C_GREEN, fontsize=7, bold=True) +box(0.3, 8.3, 3.1, 0.5, "REST /v1/audio/transcriptions", color=C_GREEN, fontsize=7) +box(0.3, 7.4, 3.1, 0.5, "Health • /v1/models", color=C_GREEN, fontsize=7) + +# Clients +ax.text(0.2, 6.5, "Clients:", fontsize=7, color=C_TEXTDIM, family="monospace") +for i, client in enumerate(["Browser", "OpenAI SDK", "Deepgram SDK", "TestHarness"]): + box(0.3 + i * 0.9, 5.8, 0.8, 0.5, client, fontsize=5.5, bg="#1a2a1a", color="#3a6a3a") + +# ═══════════════════════════════════════════════════════════════════ +# Centre: Audio Processor (per-session pipeline) +# ═══════════════════════════════════════════════════════════════════ +section_box(4.0, 5.5, 5.5, 5.5, "Audio Processor (per session)", border=C_BLUE) + +box(4.3, 10.0, 2.0, 0.6, "FFmpeg\nDecoding", color=C_BLUE, bg="#1a2a4a", bold=True) +arrow(3.6, 9.4, 4.3, 10.2, color=C_GREEN) + +box(6.6, 10.0, 2.6, 0.6, "Silero VAD\nspeech / silence", color=C_BLUE, bg="#1a2a4a") +arrow(6.3, 10.3, 6.6, 10.3, color=C_BLUE) + +box(4.3, 8.8, 4.9, 0.8, "SessionASRProxy\nthread-safe per-session language override", color=C_BLUE, fontsize=7) +arrow(6.0, 10.0, 6.0, 9.6, color=C_BLUE) + +box(4.3, 7.6, 2.3, 0.8, "DiffTracker\n(opt-in ?mode=diff)", color="#5a5a7a", fontsize=7) +box(6.9, 7.6, 2.3, 0.8, "Result Formatter\n→ FrontData.to_dict()", color=C_BLUE, fontsize=7) + +# Streaming policies +ax.text(4.3, 7.1, "Streaming policies:", fontsize=7, color=C_ORANGE, fontweight="bold", family="monospace") +box(4.3, 6.2, 2.3, 0.7, "LocalAgreement\nHypothesisBuffer", color=C_ORANGE, bg="#2a2a1a", fontsize=7) +box(6.9, 6.2, 2.3, 0.7, "SimulStreaming\nAlignAtt (Whisper)", color=C_ORANGE, bg="#2a2a1a", fontsize=7) + +# ═══════════════════════════════════════════════════════════════════ +# Right: TranscriptionEngine (singleton) +# ═══════════════════════════════════════════════════════════════════ +section_box(10.0, 0.3, 9.8, 10.7, "TranscriptionEngine (singleton — shared across sessions)", + border=C_ACCENT, bg="#1e1520") + +ax.text(10.2, 10.5, "6 ASR Backends", fontsize=9, color=C_ACCENT, fontweight="bold", family="monospace") + +# ── Whisper backends ── +section_box(10.2, 7.3, 4.5, 3.0, "Whisper Family (chunk-based)", border=C_PURPLE, bg=C_BOX_BG2) + +box(10.4, 9.2, 1.3, 0.6, "Faster\nWhisper", color=C_PURPLE, bg="#2a1a3a", fontsize=7, bold=True) +box(11.9, 9.2, 1.3, 0.6, "MLX\nWhisper", color=C_PURPLE, bg="#2a1a3a", fontsize=7, bold=True) +box(13.4, 9.2, 1.1, 0.6, "OpenAI\nWhisper", color=C_PURPLE, bg="#2a1a3a", fontsize=7) + +ax.text(10.4, 8.7, "PCM → Encoder → Decoder → Tokens", fontsize=6.5, color=C_TEXTDIM, family="monospace") +ax.text(10.4, 8.3, "Uses LocalAgreement or SimulStreaming (AlignAtt)", fontsize=6, color=C_PURPLE, family="monospace") +ax.text(10.4, 7.9, "Language detection • Buffer trimming", fontsize=6, color=C_TEXTDIM, family="monospace") +ax.text(10.4, 7.5, "CPU / CUDA / MLX", fontsize=6, color=C_TEXTDIM, family="monospace") + +# ── Voxtral backends ── +section_box(10.2, 3.8, 4.5, 3.2, "Voxtral (native streaming)", border=C_PINK, bg="#2a1520") + +box(10.4, 5.9, 1.8, 0.6, "Voxtral MLX\n(Apple Silicon)", color=C_PINK, bg="#2a1520", fontsize=7, bold=True) +box(12.5, 5.9, 2.0, 0.6, "Voxtral HF\n(CUDA/MPS/CPU)", color=C_PINK, bg="#2a1520", fontsize=7, bold=True) + +ax.text(10.4, 5.4, "Incremental encoder → Autoregressive decoder", fontsize=6.5, color=C_TEXTDIM, family="monospace") +ax.text(10.4, 5.0, "Sliding KV cache • Token-by-token output", fontsize=6, color=C_PINK, family="monospace") +ax.text(10.4, 4.6, "No chunking needed — truly streams audio", fontsize=6, color=C_TEXTDIM, family="monospace") +ax.text(10.4, 4.2, "4B params • 15 languages • 6-bit quant (MLX)", fontsize=6, color=C_TEXTDIM, family="monospace") + +# ── Qwen3 backend ── +section_box(15.0, 3.8, 4.6, 3.2, "Qwen3 ASR (batch + aligner)", border=C_GREEN, bg=C_BOX_BG3) + +box(15.2, 5.9, 2.0, 0.6, "Qwen3 ASR\n1.7B / 0.6B", color=C_GREEN, bg="#1a3a2a", fontsize=7, bold=True) +box(17.4, 5.9, 2.0, 0.6, "Forced\nAligner", color=C_GREEN, bg="#1a3a2a", fontsize=7) + +ax.text(15.2, 5.4, "Full-audio batch inference", fontsize=6.5, color=C_TEXTDIM, family="monospace") +ax.text(15.2, 5.0, "ForcedAligner provides word timestamps", fontsize=6, color=C_GREEN, family="monospace") +ax.text(15.2, 4.6, "Uses LocalAgreement for streaming output", fontsize=6, color=C_TEXTDIM, family="monospace") +ax.text(15.2, 4.2, "12 languages • CUDA/MPS/CPU", fontsize=6, color=C_TEXTDIM, family="monospace") + +# ── OpenAI API ── +box(15.2, 7.7, 4.2, 0.6, "OpenAI API (cloud)", color="#5a6a7a", fontsize=7) +ax.text(15.2, 7.4, "Remote transcription • API key required", fontsize=6, color=C_TEXTDIM, family="monospace") + +# ── Shared components ── +section_box(10.2, 0.5, 9.4, 3.0, "Shared Components", border="#5a6a7a", bg="#151520") + +box(10.4, 2.2, 2.5, 0.8, "Mel Spectrogram\ncached DFT + filterbank", + color="#5a6a7a", fontsize=7) +box(13.2, 2.2, 2.5, 0.8, "Diarization\nSortformer / pyannote", + color="#5a6a7a", fontsize=7) +box(16.0, 2.2, 3.4, 0.8, "Translation\nNLLB • CTranslate2", + color="#5a6a7a", fontsize=7) + +box(10.4, 0.8, 4.0, 0.8, "WhisperLiveKitConfig\n(single source of truth)", + color=C_ACCENT, fontsize=7, bold=True) +box(14.8, 0.8, 4.6, 0.8, "TestHarness\nfull pipeline testing without server", + color="#5a6a7a", fontsize=7) + +# ═══════════════════════════════════════════════════════════════════ +# Arrows: main data flow +# ═══════════════════════════════════════════════════════════════════ + +# Audio processor → TranscriptionEngine +arrow(9.5, 8.5, 10.2, 8.5, color=C_ACCENT, lw=2) +ax.text(9.6, 8.8, "PCM audio", fontsize=6, color=C_ACCENT, family="monospace") + +# TranscriptionEngine → Audio processor (results) +arrow(10.2, 7.0, 9.5, 7.0, color=C_GREEN, lw=2) +ax.text(9.6, 7.3, "ASRTokens", fontsize=6, color=C_GREEN, family="monospace") + +# Streaming policy connections +arrow(5.5, 6.2, 5.5, 5.5, color=C_ORANGE, style="->") +arrow(8.1, 6.2, 8.1, 5.5, color=C_ORANGE, style="->") +ax.text(4.3, 5.6, "Whisper + Qwen3", fontsize=5.5, color=C_ORANGE, family="monospace") +ax.text(6.9, 5.6, "Whisper + Qwen3-simul", fontsize=5.5, color=C_ORANGE, family="monospace") + +# Voxtral note (no policy needed) +ax.text(10.2, 3.5, "Voxtral: own streaming processor (no external policy)", fontsize=6, + color=C_PINK, family="monospace", style="italic") + + +# ═══════════════════════════════════════════════════════════════════ +# Legend +# ═══════════════════════════════════════════════════════════════════ +legend_y = 5.0 +ax.text(0.3, legend_y, "Streaming modes:", fontsize=7, color=C_TEXT, fontweight="bold", family="monospace") +for i, (label, color) in enumerate([ + ("Native streaming (Voxtral)", C_PINK), + ("Chunk-based (Whisper)", C_PURPLE), + ("Batch + aligner (Qwen3)", C_GREEN), +]): + ax.plot([0.3], [legend_y - 0.4 - i * 0.35], "s", color=color, markersize=6) + ax.text(0.6, legend_y - 0.4 - i * 0.35, label, fontsize=6.5, color=color, + va="center", family="monospace") + + +plt.savefig("architecture.png", dpi=200, facecolor=C_BG, bbox_inches="tight", pad_inches=0.1) +print("Saved architecture.png")