# The current problem list corresponds to a sequence length parameter value
# of 40 (seq_len = 40) and a beam parameter value of 4,
# with batch = 56 (num_cores)

# =============================================================================
# FWD
# =============================================================================

# Start from default settings, then select bf16 as the data type.
--reset
--dt=bf16

# -----------------------------------------------------------------------------
# Encoder part
# num_encoder_stages = 6
#   2d problems - M = batch * seq_len
#   4d problems - M = seq_len, B = batch x 16
# -----------------------------------------------------------------------------

# 2d problem, all tensors tagged ab, bf16 bias with mask 2.
--stag=ab
--wtag=ab
--dtag=ab
--bia-dt=bf16
--bia_mask=2
2240x1024:1024x1024n"Transformer_lt:FWD,Encoder_MM_1*36"

# 4d problem without bias; weights are tagged abdc.
--bia-dt=undef
--stag=abcd
--wtag=abdc
--dtag=abcd
56x16x40x64:56x16x64x40n"Transformer_lt:FWD,Encoder_MM_4*6"

# 4d problem without bias; all tensors tagged abcd.
--stag=abcd
--wtag=abcd
--dtag=abcd
56x16x40x40:56x16x40x64n"Transformer_lt:FWD,Encoder_MM_5*6"

# 2d problems with the bf16 bias switched back on (1024 -> 4096 -> 1024).
--stag=ab
--wtag=ab
--dtag=ab
--bia-dt=bf16
--bia_mask=2
2240x1024:1024x4096n"Transformer_lt:FWD,Encoder_MM_7*6"
2240x4096:4096x1024n"Transformer_lt:FWD,Encoder_MM_8*6"

# -----------------------------------------------------------------------------
# Decoder part
#   2d problems - M = batch * beam
#   4d problems - M = beam, B = batch x 16
# -----------------------------------------------------------------------------
# 2d problem. It runs with the bf16 bias and ab tags that are still in effect
# from the last encoder group; they are restated so the dependency is visible.
--bia-dt=bf16 --bia_mask=2
--stag=ab --wtag=ab --dtag=ab
224x1024:1024x1024n"Transformer_lt:FWD,Decoder_MM_1*36"

# 4d problems without bias, M = 4.
--bia-dt=undef
--stag=abcd --wtag=abdc --dtag=abcd
56x16x4x64:56x16x64x40n"Transformer_lt:FWD,Decoder_MM_4*6"

--stag=abcd --wtag=abcd --dtag=abcd
56x16x4x40:56x16x40x64n"Transformer_lt:FWD,Decoder_MM_5*6"

# 2d problems with bf16 bias, all tensors tagged ab.
--bia-dt=bf16 --bia_mask=2
--stag=ab --wtag=ab --dtag=ab
224x1024:1024x4096n"Transformer_lt:FWD,Decoder_MM_7*6"
224x4096:4096x1024n"Transformer_lt:FWD,Decoder_MM_8*6"

# Vocabulary problem: bias stays enabled, weights are tagged ba.
--stag=ab --wtag=ba --dtag=ab
224x1024:1024x32768n"Transformer_lt:FWD,Decoder_vocabulary*1"

# 4d problems without bias, M = 1.
--bia-dt=undef
--stag=abcd --wtag=abdc --dtag=abcd
56x16x1x64:56x16x64x40n"Transformer_lt:FWD,Decoder_MM_xx40*6"

--stag=abcd --wtag=abcd --dtag=abcd
56x16x1x40:56x16x40x64n"Transformer_lt:FWD,Decoder_MM_yy40*6"

# =============================================================================
# BWD/D
# =============================================================================
# -----------------------------------------------------------------------------
# Encoder part
# num_encoder_stages = 6
#   2d problems - M = batch * seq_len
#   4d problems - M = seq_len, B = batch x 16
# -----------------------------------------------------------------------------

# These problems run without bias. The bias data type is already undef at this
# point in the list; it is restated here for readability.
--bia-dt=undef

# 2d problem, weights tagged ba.
--stag=ab
--wtag=ba
--dtag=ab
2240x1024:1024x1024n"Transformer_lt:BWD_D,Encoder_MM_1*36"

# 4d problems, all tensors tagged abcd.
--stag=abcd
--wtag=abcd
--dtag=abcd
56x16x40x40:56x16x40x64n"Transformer_lt:BWD_D,Encoder_MM_4_A*6"
56x16x64x40:56x16x40x40n"Transformer_lt:BWD_D,Encoder_MM_4_B*12"

# 4d problem, source tagged abdc.
--stag=abdc
--wtag=abcd
--dtag=abcd
56x16x40x40:56x16x40x64n"Transformer_lt:BWD_D,Encoder_MM_5_B*6"

# 2d problems, weights tagged ba.
--stag=ab
--wtag=ba
--dtag=ab
2240x4096:4096x1024n"Transformer_lt:BWD_D,Encoder_MM_7*6"
2240x1024:1024x4096n"Transformer_lt:BWD_D,Encoder_MM_8*6"

# -----------------------------------------------------------------------------
# Decoder part
#   2d problems - M = batch * beam
#   4d problems - M = beam, B = batch x 16
# -----------------------------------------------------------------------------
# 2d problem. It runs with the ab/ba/ab tags still in effect from the last
# encoder group; they are restated so the dependency is visible.
--stag=ab --wtag=ba --dtag=ab
224x1024:1024x1024n"Transformer_lt:BWD_D,Decoder_MM_1*36"

# 4d problems, all tensors tagged abcd.
--stag=abcd --wtag=abcd --dtag=abcd
56x16x4x40:56x16x40x64n"Transformer_lt:BWD_D,Decoder_MM_4_A*6"
56x16x64x4:56x16x4x40n"Transformer_lt:BWD_D,Decoder_MM_4_B*12"

# 4d problem, source tagged abdc.
--stag=abdc --wtag=abcd --dtag=abcd
56x16x40x4:56x16x4x64n"Transformer_lt:BWD_D,Decoder_MM_5_B*6"

# 2d problems, weights tagged ba; the last one is the vocabulary problem.
--stag=ab --wtag=ba --dtag=ab
224x4096:4096x1024n"Transformer_lt:BWD_D,Decoder_MM_7*6"
224x1024:1024x4096n"Transformer_lt:BWD_D,Decoder_MM_8*6"
224x32768:32768x1024n"Transformer_lt:BWD_D,Decoder_vocabulary*1"

# 4d problems with a dimension of 1, all tensors tagged abcd.
--stag=abcd --wtag=abcd --dtag=abcd
56x16x1x40:56x16x40x64n"Transformer_lt:BWD_D,Decoder_MM_xx40_A*6"
56x16x64x1:56x16x1x40n"Transformer_lt:BWD_D,Decoder_MM_xx40_B*12"

# 4d problem with a dimension of 1, source tagged abdc.
--stag=abdc --wtag=abcd --dtag=abcd
56x16x40x1:56x16x1x64n"Transformer_lt:BWD_D,Decoder_MM_yy40_B*6"

# =============================================================================
# BWD/W
# =============================================================================
# Settings shared by every problem in this section: a three-part data type
# (bf16:bf16:f32), source tagged ba, weights and destination tagged ab.
--dt=bf16:bf16:f32
--stag=ba
--wtag=ab
--dtag=ab

# -----------------------------------------------------------------------------
# Encoder part
# num_encoder_stages = 6
#   2d problems - K = batch * seq_len
# -----------------------------------------------------------------------------
1024x2240:2240x1024n"Transformer_lt:BWD_W,Encoder_MM_1*36"
1024x2240:2240x4096n"Transformer_lt:BWD_W,Encoder_MM_7*6"
4096x2240:2240x1024n"Transformer_lt:BWD_W,Encoder_MM_8*6"

# -----------------------------------------------------------------------------
# Decoder part
#   2d problems - K = batch * beam
# -----------------------------------------------------------------------------
1024x224:224x1024n"Transformer_lt:BWD_W,Decoder_MM_1*36"
1024x224:224x4096n"Transformer_lt:BWD_W,Decoder_MM_7*6"
4096x224:224x1024n"Transformer_lt:BWD_W,Decoder_MM_8*6"
