Matrix Equation 高斯消元解异或方程
题目大意:
给你一个 \(A\) 矩阵和一个 \(B\) 矩阵,求一个 \(C\) 矩阵满足,\(A\times C = B \bigodot C\)
\(Zi,j=(∑_{k=1}^NX_{i,k}Y_{k,j})\,mod\,2\)
\(D_{i,j}=X_{i,j}Y_{i,j}\)
\(Z_{i,j}\) 表示 \(×\) ,\(D_{i,j}\) 表示 \(\bigodot\)
题解:
首先模拟一下这个式子,观察发现,对于 \(C\) 这个矩阵来说,列与列之间是没什么关系的,每一列可以单独作为未知变量求解,因为要求这个方案数,所以我需要知道有多少个自由变元,自由变元就是指那些不受限制,可以任意取值的变量,又因为 \(C\) 的每一个数只有两个取值,所以对于每一个自由变元,答案都可以乘以2。
a11 a12 a13
a21 a22 a23
a31 a32 a33
b11 b12 b13
b21 b22 b23
b31 b32 b33
c11 c12 c13
c21 c22 c23
c31 c32 c33
axc = b*c
(a11-b11)*c11 + a12*c21 + a13*c31 = 0
a21*c11 + (a22-b21)*c21 + a23*c31 = 0
a31*c11 + a32*c21 + (a33-b31)*c31 = 0
#include <bits/stdc++.h>
#define inf 0x3f3f3f3f
using namespace std;
typedef long long ll;
const int mod = 998244353;
const int maxn = 300;
const int MAXN = 300;
//有 equ 个方程,var 个变元。增广矩阵行数为 equ, 列数为 var+1, 分别为 0 到 var
int equ,var;
int a[MAXN][MAXN]; //增广矩阵
int x[MAXN]; //解集
int free_x[MAXN];//用来存储自由变元(多解枚举自由变元可以使用)
int free_num;//自由变元的个数
//返回值为 -1 表示无解,为 0 是唯一解,否则返回自由变元个数
int Gauss(){
int max_r,col,k;
free_num = 0;
for(k=0,col = 0;k<equ&&col<var;k++,col++){// k = equ,col = var
max_r = k;
for(int i=k+1;i<equ;i++){
if(abs(a[i][col])>abs(a[max_r][col])) max_r = i;
}
if(a[max_r][col]==0){
k--;
free_x[free_num++] = col;//这个是自由变元
continue;
}
if(max_r!=k){
for(int j=col;j<var+1;j++) swap(a[k][j],a[max_r][j]);
}
for(int i=k+1;i<equ;i++){
if(a[i][col]!=0){
for(int j=col;j<var+1;j++) a[i][j]^=a[k][j];
}
}
}
//无解:化简的增广阵中存在(0,0,...,a)这样的行,且a!=0
for(int i=k;i<equ;i++){
if(a[i][col]!=0) return -1;
}
// printf("!!! k = %d col = %d\n",k,col);
if(k<var) return var - k;// 自由变元的数量
for(int i=var-1;i>=0;i--){
x[i] = a[i][var];
for(int j = i+1;j<var;j++){
x[i]^=(a[i][j]&&x[j]);
}
}
return 0;
}
int b[maxn][maxn],c[maxn][maxn];
ll binpow(ll x,ll k){
ll ans = 1;
x %= mod;
while (k){
if(k&1) ans = ans*x%mod;
k>>=1,x = x * x % mod;
}
return ans;
}
int n;
void Print(){
printf("debug:Print\n");
for(int i=0;i<n;i++){
for(int j=0;j<n;j++){
// printf("a[%d][%d]=%d\n",i,j,a[i][j]);
printf("%3d",a[i][j]);
}
printf("\n");
}
}
int main(){
ll ans = 1;
scanf("%d",&n);
for(int i=0;i<n;i++){
for(int j=0;j<n;j++){
scanf("%d",&c[i][j]);
}
}
for(int i=0;i<n;i++){
for(int j=0;j<n;j++){
scanf("%d",&b[i][j]);
}
}
for(int s=0;s<n;s++){
for(int i=0;i<n;i++){
for(int j=0;j<n;j++){
a[i][j] = c[i][j];
}
}
for(int i=0;i<n;i++) a[i][i] = (a[i][i] - b[i][s] + 2)%2;
// Print();
equ = var = n;
int res = Gauss();
if(res==-1){
ans = -1;
break;
}
// printf("s = %d ans = %lld\n",s,ans);
ans = ans * binpow(2,res) % mod;
}
if(ans==-1) printf("0\n");
else printf("%lld\n",ans);
return 0;
}