diff --git a/include/linux/skbuff.h b/include/linux/skbuff.h
index 6f0b3e0adc73674f56449f49e61a7264463efd5f..0f665cb26b505729fad04be94ead98ff99e79359 100644
--- a/include/linux/skbuff.h
+++ b/include/linux/skbuff.h
@@ -2847,6 +2847,18 @@ static inline int skb_linearize_cow(struct sk_buff *skb)
 	       __skb_linearize(skb) : 0;
 }
 
+static __always_inline void
+__skb_postpull_rcsum(struct sk_buff *skb, const void *start, unsigned int len,
+		     unsigned int off)
+{
+	if (skb->ip_summed == CHECKSUM_COMPLETE)
+		skb->csum = csum_block_sub(skb->csum,
+					   csum_partial(start, len, 0), off);
+	else if (skb->ip_summed == CHECKSUM_PARTIAL &&
+		 skb_checksum_start_offset(skb) < 0)
+		skb->ip_summed = CHECKSUM_NONE;
+}
+
 /**
  *	skb_postpull_rcsum - update checksum for received skb after pull
  *	@skb: buffer to update
@@ -2857,36 +2869,38 @@ static inline int skb_linearize_cow(struct sk_buff *skb)
  *	update the CHECKSUM_COMPLETE checksum, or set ip_summed to
  *	CHECKSUM_NONE so that it can be recomputed from scratch.
  */
-
 static inline void skb_postpull_rcsum(struct sk_buff *skb,
 				      const void *start, unsigned int len)
 {
-	if (skb->ip_summed == CHECKSUM_COMPLETE)
-		skb->csum = csum_sub(skb->csum, csum_partial(start, len, 0));
-	else if (skb->ip_summed == CHECKSUM_PARTIAL &&
-		 skb_checksum_start_offset(skb) < 0)
-		skb->ip_summed = CHECKSUM_NONE;
+	__skb_postpull_rcsum(skb, start, len, 0);
 }
 
-unsigned char *skb_pull_rcsum(struct sk_buff *skb, unsigned int len);
+static __always_inline void
+__skb_postpush_rcsum(struct sk_buff *skb, const void *start, unsigned int len,
+		     unsigned int off)
+{
+	if (skb->ip_summed == CHECKSUM_COMPLETE)
+		skb->csum = csum_block_add(skb->csum,
+					   csum_partial(start, len, 0), off);
+}
 
+/**
+ *	skb_postpush_rcsum - update checksum for received skb after push
+ *	@skb: buffer to update
+ *	@start: start of data after push
+ *	@len: length of data pushed
+ *
+ *	After doing a push on a received packet, you need to call this to
+ *	update the CHECKSUM_COMPLETE checksum.
+ */
 static inline void skb_postpush_rcsum(struct sk_buff *skb,
 				      const void *start, unsigned int len)
 {
-	/* For performing the reverse operation to skb_postpull_rcsum(),
-	 * we can instead of ...
-	 *
-	 *   skb->csum = csum_add(skb->csum, csum_partial(start, len, 0));
-	 *
-	 * ... just use this equivalent version here to save a few
-	 * instructions. Feeding csum of 0 in csum_partial() and later
-	 * on adding skb->csum is equivalent to feed skb->csum in the
-	 * first place.
-	 */
-	if (skb->ip_summed == CHECKSUM_COMPLETE)
-		skb->csum = csum_partial(start, len, skb->csum);
+	__skb_postpush_rcsum(skb, start, len, 0);
 }
 
+unsigned char *skb_pull_rcsum(struct sk_buff *skb, unsigned int len);
+
 /**
  *	skb_push_rcsum - push skb and update receive checksum
  *	@skb: buffer to update
diff --git a/net/core/filter.c b/net/core/filter.c
index 5708999f8a7945ec738e2043057817403e1b64ef..b5add4ef0d1d21d5a830af3dfe3dd2ba34d07976 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -1365,6 +1365,18 @@ static inline int bpf_try_make_writable(struct sk_buff *skb,
 	return err;
 }
 
+static inline void bpf_push_mac_rcsum(struct sk_buff *skb)
+{
+	if (skb_at_tc_ingress(skb))
+		skb_postpush_rcsum(skb, skb_mac_header(skb), skb->mac_len);
+}
+
+static inline void bpf_pull_mac_rcsum(struct sk_buff *skb)
+{
+	if (skb_at_tc_ingress(skb))
+		skb_postpull_rcsum(skb, skb_mac_header(skb), skb->mac_len);
+}
+
 static u64 bpf_skb_store_bytes(u64 r1, u64 r2, u64 r3, u64 r4, u64 flags)
 {
 	struct bpf_scratchpad *sp = this_cpu_ptr(&bpf_sp);
@@ -1395,7 +1407,7 @@ static u64 bpf_skb_store_bytes(u64 r1, u64 r2, u64 r3, u64 r4, u64 flags)
 		return -EFAULT;
 
 	if (flags & BPF_F_RECOMPUTE_CSUM)
-		skb_postpull_rcsum(skb, ptr, len);
+		__skb_postpull_rcsum(skb, ptr, len, offset);
 
 	memcpy(ptr, from, len);
 
@@ -1404,7 +1416,7 @@ static u64 bpf_skb_store_bytes(u64 r1, u64 r2, u64 r3, u64 r4, u64 flags)
 		skb_store_bits(skb, offset, ptr, len);
 
 	if (flags & BPF_F_RECOMPUTE_CSUM)
-		skb_postpush_rcsum(skb, ptr, len);
+		__skb_postpush_rcsum(skb, ptr, len, offset);
 	if (flags & BPF_F_INVALIDATE_HASH)
 		skb_clear_hash(skb);
 
@@ -1607,9 +1619,6 @@ static const struct bpf_func_proto bpf_csum_diff_proto = {
 
 static inline int __bpf_rx_skb(struct net_device *dev, struct sk_buff *skb)
 {
-	if (skb_at_tc_ingress(skb))
-		skb_postpush_rcsum(skb, skb_mac_header(skb), skb->mac_len);
-
 	return dev_forward_skb(dev, skb);
 }
 
@@ -1648,6 +1657,8 @@ static u64 bpf_clone_redirect(u64 r1, u64 ifindex, u64 flags, u64 r4, u64 r5)
 	if (unlikely(!skb))
 		return -ENOMEM;
 
+	bpf_push_mac_rcsum(skb);
+
 	return flags & BPF_F_INGRESS ?
 	       __bpf_rx_skb(dev, skb) : __bpf_tx_skb(dev, skb);
 }
@@ -1693,6 +1704,8 @@ int skb_do_redirect(struct sk_buff *skb)
 		return -EINVAL;
 	}
 
+	bpf_push_mac_rcsum(skb);
+
 	return ri->flags & BPF_F_INGRESS ?
 	       __bpf_rx_skb(dev, skb) : __bpf_tx_skb(dev, skb);
 }
@@ -1756,7 +1769,10 @@ static u64 bpf_skb_vlan_push(u64 r1, u64 r2, u64 vlan_tci, u64 r4, u64 r5)
 		     vlan_proto != htons(ETH_P_8021AD)))
 		vlan_proto = htons(ETH_P_8021Q);
 
+	bpf_push_mac_rcsum(skb);
 	ret = skb_vlan_push(skb, vlan_proto, vlan_tci);
+	bpf_pull_mac_rcsum(skb);
+
 	bpf_compute_data_end(skb);
 	return ret;
 }
@@ -1776,7 +1792,10 @@ static u64 bpf_skb_vlan_pop(u64 r1, u64 r2, u64 r3, u64 r4, u64 r5)
 	struct sk_buff *skb = (struct sk_buff *) (long) r1;
 	int ret;
 
+	bpf_push_mac_rcsum(skb);
 	ret = skb_vlan_pop(skb);
+	bpf_pull_mac_rcsum(skb);
+
 	bpf_compute_data_end(skb);
 	return ret;
 }