// RUN: mlir-opt %s --xevm-attach-target='module=xevm_* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-ND
// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-SCATTER


gpu.module @xevm_module {
gpu.func @store_1D_vector(%vec: vector<8xf32>,
    %source: memref<8x16x32xf32>, %offset: index) {
  vector.transfer_write %vec, %source[%offset, %offset, %offset]
    {in_bounds = [true]}
    : vector<8xf32>, memref<8x16x32xf32>
  gpu.return
}

// STORE-ND-LABEL: @store_1D_vector(
// STORE-ND-SAME:  %[[VEC:.+]]: vector<8xf32>,
// STORE-ND-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
// STORE-ND-SAME:  %[[OFFSET:.+]]: index
// STORE-ND:       %[[ELEM_BYTES:.+]] = arith.constant 4 : index
// STORE-ND:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
// STORE-ND:       %[[BASE_BUFFER:.+]], %[[OFFSET1:.+]], %[[SIZES:.+]], %[[STRIDES:.+]] = memref.extract_strided_metadata %[[COLLAPSED]]
// STORE-ND-SAME:    : memref<32xf32, strided<[1], offset: ?>> -> memref<f32>, index, index, index
// STORE-ND:       %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
// STORE-ND-SAME:    : memref<f32> -> index
// STORE-ND:       %[[MUL:.+]] = arith.muli %[[OFFSET1]], %[[ELEM_BYTES]] : index
// STORE-ND:       %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
// STORE-ND:       %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [32],
// STORE-ND-SAME:                   strides : [1] : i64  -> !xegpu.tensor_desc<8xf32,
// STORE-ND-SAME:    boundary_check = false
// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>

// STORE-SCATTER-LABEL:  @store_1D_vector(
// STORE-SCATTER-SAME:   %[[VEC:.+]]: vector<8xf32>,
// STORE-SCATTER-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
// STORE-SCATTER-DAG:        %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
// STORE-SCATTER-DAG:        %[[STEP:.+]] = vector.step
// STORE-SCATTER-COUNT2: arith.muli {{.*}} : index
// STORE-SCATTER-COUNT2: arith.addi {{.*}} : index
// STORE-SCATTER-DAG:    %[[BCAST:.+]] = vector.broadcast {{.*}} : index to vector<8xindex>
// STORE-SCATTER-DAG:    %[[IDX:.+]] = arith.addi %[[BCAST]], %{{.*}} : vector<8xindex>
// STORE-SCATTER-DAG:    %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
// STORE-SCATTER-DAG:    %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// STORE-SCATTER:       xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8xf32>, i64, vector<8xindex>, vector<8xi1>
}

// -----
gpu.module @xevm_module {
gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
    %source: memref<8x16x32xf32>, %offset: index) {
  vector.transfer_write %vec, %source[%offset, %offset, %offset]
    {in_bounds = [true, true]}
    : vector<8x16xf32>, memref<8x16x32xf32>
  gpu.return
}

// STORE-ND-LABEL: @store_2D_vector(
// STORE-ND-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
// STORE-ND-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
// STORE-ND-SAME:  %[[OFFSET:.+]]: index
// STORE-ND:       %[[ELEM_BYTES:.+]] = arith.constant 4 : index
// STORE-ND:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
// STORE-ND:       %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// STORE-ND:       %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
// STORE-ND-SAME:    : memref<f32> -> index
// STORE-ND:       %[[MUL:.+]] = arith.muli %[[OFF1]], %[[ELEM_BYTES]] : index
// STORE-ND:       %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
// STORE-ND:       %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [16, 32],
// STORE-ND-SAME:                   strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32,
// STORE-ND-SAME:    boundary_check = false
// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>

// STORE-SCATTER-LABEL:  @store_2D_vector(
// STORE-SCATTER-SAME:   %[[VEC:.+]]: vector<8x16xf32>,
// STORE-SCATTER-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
// STORE-SCATTER-SAME:   %[[OFFSET:.+]]: index
// STORE-SCATTER:        %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
// STORE-SCATTER-COUNT2: %[[STEP:.+]] = vector.step
// STORE-SCATTER-COUNT2: vector.shape_cast {{.*}}
// STORE-SCATTER-COUNT2: vector.broadcast {{.*}} : vector<8x16xindex>
// STORE-SCATTER-DAG:    %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// STORE-SCATTER-DAG:    %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}} : vector<8x16xindex>
// STORE-SCATTER-DAG:    %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
// STORE-SCATTER-DAG:    %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// STORE-SCATTER:        xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1>
}

// -----
gpu.module @xevm_module {
gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
    %source: memref<?x?x?xf32>, %i: index, %j: index, %k: index) {
  vector.transfer_write %vec, %source[%i, %j, %k]
    {in_bounds = [true, true]}
    : vector<8x16xf32>, memref<?x?x?xf32>
  gpu.return
}

// STORE-ND-LABEL: @store_dynamic_source(
// STORE-ND-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
// STORE-ND-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
// STORE-ND-SAME:  %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// STORE-ND:       %[[ELEM_BYTES:.+]] = arith.constant 4 : index
// STORE-ND:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
// STORE-ND:       %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// STORE-ND:       %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
// STORE-ND:       %[[MUL:.+]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
// STORE-ND:       %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
// STORE-ND:       %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
// STORE-ND-SAME:                   strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor
// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>

// STORE-SCATTER-LABEL: @store_dynamic_source(
// STORE-SCATTER-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
// STORE-SCATTER-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
// STORE-SCATTER-DAG:   %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
// STORE-SCATTER-DAG:   memref.extract_strided_metadata %[[SRC]] : memref<?x?x?xf32> -> memref<f32>, index, index, index, index, index, index, index
// STORE-SCATTER-COUNT2: %[[STEP:.+]] = vector.step
// STORE-SCATTER-COUNT2: vector.shape_cast {{.*}}
// STORE-SCATTER-COUNT2: vector.broadcast {{.*}} : vector<8x16xindex>
// STORE-SCATTER-DAG:   %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// STORE-SCATTER-DAG:   %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}} : vector<8x16xindex>
// STORE-SCATTER-DAG:   %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x?x?xf32> -> index
// STORE-SCATTER-DAG:   %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// STORE-SCATTER:       xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1>
}

// -----
gpu.module @xevm_module {
gpu.func @store_out_of_bounds(%vec: vector<8x16xf32>,
    %source: memref<7x64xf32>, %offset: index) {
  vector.transfer_write %vec, %source[%offset, %offset]
    {in_bounds = [false, true]}
    : vector<8x16xf32>, memref<7x64xf32>
  gpu.return
}

// STORE-ND-LABEL:   @store_out_of_bounds(
// STORE-ND-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
// STORE-ND-SAME:  %[[SRC:.+]]: memref<7x64xf32>,
// STORE-ND-SAME:  %[[OFFSET:.+]]: index
// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc
// STORE-ND-SAME:    %[[SRC]]
// STORE-ND-SAME:    memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>

// STORE-SCATTER-LABEL:  @store_out_of_bounds(
// STORE-SCATTER:   vector.transfer_write
}

// -----
gpu.module @xevm_module {
gpu.func @no_store_transposed(%vec: vector<8x16xf32>,
    %source: memref<32x64xf32>, %offset: index) {
  vector.transfer_write %vec, %source[%offset, %offset]
    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
    in_bounds = [true, true]}
    : vector<8x16xf32>, memref<32x64xf32>
  gpu.return
}

// STORE-ND-LABEL: @no_store_transposed(
// STORE-ND:       vector.transfer_write

// STORE-SCATTER-LABEL:  @no_store_transposed(
// STORE-SCATTER-SAME:   %[[VEC:.+]]: vector<8x16xf32>,
// STORE-SCATTER-SAME:   %[[SRC:.+]]: memref<32x64xf32>,
// STORE-SCATTER-SAME:   %[[OFFSET:.+]]: index
// STORE-SCATTER:        %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
// STORE-SCATTER-COUNT2: %[[STEP:.+]] = vector.step
// STORE-SCATTER-COUNT2: vector.shape_cast {{.*}}
// STORE-SCATTER-COUNT2: vector.broadcast {{.*}} : vector<8x16xindex>
// STORE-SCATTER-DAG:    %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// STORE-SCATTER-DAG:    %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}} : vector<8x16xindex>
// STORE-SCATTER-DAG:    %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<32x64xf32> -> index
// STORE-SCATTER-DAG:    %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// STORE-SCATTER:        xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1>
}

// -----
gpu.module @xevm_module {
gpu.func @store_high_dim_vector(%vec: vector<8x16x32xf32>,
    %source: memref<16x32x64xf32>, %offset: index) {
  vector.transfer_write %vec, %source[%offset, %offset, %offset]
    {in_bounds = [true, true, true]}
    : vector<8x16x32xf32>, memref<16x32x64xf32>
  gpu.return
}

// STORE-ND-LABEL: @store_high_dim_vector(
// STORE-ND:       vector.transfer_write

// STORE-SCATTER-LABEL:  @store_high_dim_vector(
// STORE-SCATTER-SAME:   %[[VEC:.+]]: vector<8x16x32xf32>,
// STORE-SCATTER-SAME:   %[[SRC:.+]]: memref<16x32x64xf32>
// STORE-SCATTER:        %[[CST:.+]] = arith.constant dense<true> : vector<8x16x32xi1>
// STORE-SCATTER:        %[[CST_0:.+]] = arith.constant dense<64> : vector<16xindex>
// STORE-SCATTER:        %[[CST_1:.+]] = arith.constant dense<2048> : vector<8xindex>
// STORE-SCATTER:        %[[C2048:.+]] = arith.constant 2048 : index
// STORE-SCATTER:        %[[C64:.+]] = arith.constant 64 : index
// STORE-SCATTER-COUNT3: vector.step
// STORE-SCATTER-COUNT3: vector.shape_cast
// STORE-SCATTER-COUNT3: vector.broadcast {{.*}} : vector<8x16x32xindex>
// STORE-SCATTER-COUNT2: arith.addi {{.*}} : vector<8x16x32xindex>
// STORE-SCATTER:        %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
// STORE-SCATTER:        %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
// STORE-SCATTER:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<16x32x64xf32> -> index
// STORE-SCATTER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// STORE-SCATTER:        xegpu.store %[[VEC]], %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : vector<8x16x32xf32>, i64, vector<8x16x32xindex>, vector<8x16x32xi1>
}

// -----
gpu.module @xevm_module {
gpu.func @no_store_masked(%vec: vector<4xf32>,
    %source: memref<4xf32>, %offset: index) {
  %mask = arith.constant dense<[0, 1, 0, 1]> : vector<4xi1>
  vector.transfer_write %vec, %source[%offset], %mask
    {in_bounds = [true]}
    : vector<4xf32>, memref<4xf32>
  gpu.return
}

// STORE-ND-LABEL: @no_store_masked(
// STORE-ND:       vector.transfer_write

// STORE-SCATTER-LABEL:  @no_store_masked(
// STORE-SCATTER:        vector.transfer_write
}

// -----
gpu.module @xevm_module {
gpu.func @no_store_tensor(%vec: vector<8x16xf32>,
    %source: tensor<32x64xf32>, %offset: index) -> tensor<32x64xf32> {
  %0 = vector.transfer_write %vec, %source[%offset, %offset]
    {in_bounds = [true, true]}
    : vector<8x16xf32>, tensor<32x64xf32>
  gpu.return %0 : tensor<32x64xf32>
}

// STORE-ND-LABEL: @no_store_tensor(
// STORE-ND:       vector.transfer_write

// STORE-SCATTER-LABEL:  @no_store_tensor(
// STORE-SCATTER:        vector.transfer_write
}

// -----
gpu.module @xevm_module {
gpu.func @no_store_non_unit_inner_stride(%vec: vector<8xf32>,
    %source: memref<32xf32, strided<[?], offset: ?>>, %offset: index) {
  vector.transfer_write %vec, %source[%offset]
    {in_bounds = [true]}
    : vector<8xf32>, memref<32xf32, strided<[?], offset: ?>>
  gpu.return
}

// STORE-ND-LABEL: @no_store_non_unit_inner_stride(
// STORE-ND:       vector.transfer_write

// STORE-SCATTER-LABEL:  @no_store_non_unit_inner_stride(
// STORE-SCATTER:        vector.transfer_write
}

// -----
gpu.module @xevm_module {
gpu.func @no_store_unsupported_map(%vec: vector<8x16xf32>,
    %source: memref<16x32x64xf32>, %offset: index) {
  vector.transfer_write %vec, %source[%offset, %offset, %offset]
    {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)>,
    in_bounds = [true, true]}
    : vector<8x16xf32>, memref<16x32x64xf32>
  gpu.return
}

// STORE-ND-LABEL: @no_store_unsupported_map(
// STORE-ND:       vector.transfer_write

// STORE-SCATTER-LABEL:  @no_store_unsupported_map(
// STORE-SCATTER:        vector.transfer_write
}

// -----
gpu.module @xevm_module {
gpu.func @no_store_out_of_bounds_1D_vector(%vec: vector<8xf32>,
    %source: memref<8x16x32xf32>, %offset: index) {
  vector.transfer_write %vec, %source[%offset, %offset, %offset]
    {in_bounds = [false]}
    : vector<8xf32>, memref<8x16x32xf32>
  gpu.return
}

// STORE-ND-LABEL: @no_store_out_of_bounds_1D_vector(
// STORE-ND:       vector.transfer_write

// STORE-SCATTER-LABEL:  @no_store_out_of_bounds_1D_vector(
// STORE-SCATTER:        vector.transfer_write
}

// -----
gpu.module @xevm_module {
gpu.func @store_to_subview(%vec: vector<8xf16>,
    %source: memref<4096x4096xf16>, %off1: index, %off2: index) {
  %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1]
      : memref<4096x4096xf16>
        to memref<256x256xf16, strided<[4096, 1], offset: ?>>
  vector.transfer_write %vec, %subview[%off2, %off2]
      {in_bounds = [true]}
      : vector<8xf16>, memref<256x256xf16, strided<[4096, 1], offset: ?>>
  gpu.return
}
// STORE-ND-LABEL:  @store_to_subview(
// STORE-ND-SAME:   %[[VEC:.+]]: vector<8xf16>,
// STORE-ND-SAME:   %[[SRC:.+]]: memref<4096x4096xf16>,
// STORE-ND-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// STORE-ND:        %[[ELEM_BYTES:.+]] = arith.constant 2 : index
// STORE-ND:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
// STORE-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SUBVIEW]][%[[OFF2]], 0]
// STORE-ND:        %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[COLLAPSED]]
// STORE-ND:        %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
// STORE-ND:        %[[MUL:.+]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
// STORE-ND:        %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
// STORE-ND:        %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
// STORE-ND:        %[[DESC:.*]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [256], strides : [1] : i64 ->
// STORE-ND-SAME:                    !xegpu.tensor_desc<8xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
// STORE-ND:        xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF2]]] : vector<8xf16>

// STORE-SCATTER-LABEL:  @store_to_subview(
// STORE-SCATTER-SAME:   %[[VEC:.+]]: vector<8xf16>,
// STORE-SCATTER-SAME:   %[[SRC:.+]]: memref<4096x4096xf16>,
// STORE-SCATTER-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// STORE-SCATTER:        %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
// STORE-SCATTER:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1]
// STORE-SCATTER-SAME:     : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
// STORE-SCATTER:        %[[BB:.+]], %[[OFFSET:.+]], {{.*}}, {{.*}} = memref.extract_strided_metadata %[[SUBVIEW]]
// STORE-SCATTER-SAME:     : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
// STORE-SCATTER:        %[[STEP:.+]] = vector.step : vector<8xindex>
// STORE-SCATTER:        arith.muli {{.*}} : index
// STORE-SCATTER:        arith.addi %[[OFFSET]]{{.*}} : index
// STORE-SCATTER:        arith.addi {{.*}} : index
// STORE-SCATTER:        %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<8xindex>
// STORE-SCATTER:        %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
// STORE-SCATTER:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]]
// STORE-SCATTER-SAME:     : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
// STORE-SCATTER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// STORE-SCATTER:        xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8xf16>, i64, vector<8xindex>, vector<8xi1>
}
