
Questions about Spatial Aggregation and Spatial-wise Temporal Causal Self-Attention #27

Open
pisces365 opened this issue Jul 17, 2024 · 0 comments

Hi there. Thank you for your wonderful work. I have two questions:

  1. Where is the Spatial Aggregation mentioned in the paper reflected in the code?
  2. Does the Spatial-wise Temporal Causal Self-Attention refer to the following code block?
# PlanUtransformer.py, line 327
            for pose_temporal_attn, pose_temporal_norm, spatial_attn, spatial_norm, ffn, ffn_norm in pose_attn_en:
                # self-attention, normalization
                pose_queries = pose_queries + pose_temporal_attn(pose_queries, pose_tokens, pose_tokens, need_weights=False, attn_mask=self.attn_mask)[0]
                pose_queries = pose_temporal_norm(pose_queries)
                #b, f, h, w, c = queries.shape
                pose_queries = rearrange(pose_queries, 'b f c -> (b f) 1 c')
                queries = rearrange(queries, 'b f h w c -> (b f) (h w) c')
                # spatial attention layer, normalization
                pose_queries = pose_queries + spatial_attn(pose_queries, queries, queries, need_weights=False, attn_mask=None)[0]
                pose_queries = spatial_norm(pose_queries)

                # feed-forward network: applies a further non-linear transformation to the spatial-attention output to strengthen the model's expressiveness; normalize the FFN output
                pose_queries = pose_queries + ffn(pose_queries)
                pose_queries = ffn_norm(pose_queries)
                pose_queries = rearrange(pose_queries, '(b f) 1 c -> b f c', b=b, f=f)
                queries = rearrange(queries, '(b f) (h w) c -> b f h w c', b=b, f=f, h=h, w=w)

# PlanUtransformer.py, line 425
            for pose_temporal_attn, pose_temporal_norm, spatial_attn, spatial_norm, ffn, ffn_norm in pose_attn_de:
                # self-attention, normalization
                pose_queries = pose_queries + pose_temporal_attn(pose_queries, pose_tokens, pose_tokens, need_weights=False, attn_mask=self.attn_mask)[0]
                pose_queries = pose_temporal_norm(pose_queries)
                #b, f, h, w, c = queries.shape
                pose_queries = rearrange(pose_queries, 'b f c -> (b f) 1 c')
                #queries = rearrange(queries, 'b f h w c -> (b f) (h w) c')
                queries = rearrange(queries, '(b f) c h w -> (b f) (h w) c', b=b, f=f, h=h, w=w)
                # spatial attention, normalization
                pose_queries = pose_queries + spatial_attn(pose_queries, queries, queries, need_weights=False, attn_mask=None)[0]
                pose_queries = spatial_norm(pose_queries)
                
                pose_queries = pose_queries + ffn(pose_queries)
                pose_queries = ffn_norm(pose_queries)
                queries = rearrange(queries, '(b f) (h w) c -> (b f) c h w', b=b, f=f, h=h, w=w)
                pose_queries = rearrange(pose_queries, '(b f) 1 c -> b f c', b=b, f=f)
            pose_queries = pose_de_(pose_queries)
            pose_tokens = pose_de_(pose_tokens)
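For reference, my reading of the quoted loop is: a temporal causal self-attention over one pose query per frame (using `self.attn_mask` as a causal mask), followed by a per-frame cross-attention in which each frame's pose query attends over its `(h*w)` spatial tokens, matching the `rearrange` calls that fold frames into the batch dimension. Below is a minimal NumPy sketch of that pattern — my own illustration under those assumptions, not the repo's actual module (the shapes and the single-pose-token-per-frame layout are taken from the snippet; `softmax`/`attention` are hypothetical helpers):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v, mask=None):
    # scaled dot-product attention; mask entries of -inf block positions
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    if mask is not None:
        scores = scores + mask
    return softmax(scores) @ v

b, f, h, w, c = 2, 4, 3, 3, 8
rng = np.random.default_rng(0)
pose_queries = rng.standard_normal((b, f, c))     # one pose token per frame
queries = rng.standard_normal((b, f, h, w, c))    # spatial feature map per frame

# causal mask: frame t may only attend to frames <= t
mask = np.triu(np.full((f, f), -np.inf), k=1)

# temporal causal self-attention with residual connection
pose_queries = pose_queries + attention(pose_queries, pose_queries, pose_queries, mask)

# fold frames into the batch dim, as the rearrange calls in the snippet do
pq = pose_queries.reshape(b * f, 1, c)            # 'b f c -> (b f) 1 c'
sp = queries.reshape(b * f, h * w, c)             # 'b f h w c -> (b f) (h w) c'

# spatial cross-attention: each pose query attends to its own frame's tokens
pq = pq + attention(pq, sp, sp)
pose_queries = pq.reshape(b, f, c)                # '(b f) 1 c -> b f c'
print(pose_queries.shape)  # (2, 4, 8)
```

Because the frames are folded into the batch dimension before the cross-attention, no pose query can see spatial tokens from another frame — only the first attention is responsible for temporal mixing, and the causal mask keeps it one-directional.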