Commit 0918db11 by 魏建枢

代码提交

parent 19f3245f
package com.flink.processor.function;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.flink.achieve.doris.VectorAngleCalculationAchi.ResultOutput;
import com.flink.achieve.doris.VectorAngleCalculationAchi.ResultRecord;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
/**
 * Buffers the {@link ResultRecord}s of one key and, when a processing-time timer fires,
 * reduces them to a list of key points that is emitted as one {@link ResultOutput}.
 *
 * <p>The timer is armed one minute after the first record of a batch arrives. When it fires,
 * the buffered points are sorted by {@code rowNum}, filtered by {@link #selectKeyPoints(List)}
 * and serialised as a {@code [[x,y],...]} string of percentage coordinates; then all state
 * for the key is cleared.
 *
 * @author wjs
 * @version created 2025-6-26 16:07:13
 */
public class KeyPointSelector extends KeyedProcessFunction<Tuple2<String, Long>, ResultRecord, ResultOutput> {

    private static final long serialVersionUID = 1L;
    private static final Logger logger = LoggerFactory.getLogger(KeyPointSelector.class);

    // Selection thresholds: a candidate must exceed BOTH to become a key point.
    private static final double DISTANCE_THRESHOLD = 15.0;
    private static final double ANGLE_THRESHOLD = 10.0;

    // Delay between the first buffered record and the flush, in milliseconds (1 minute).
    private static final long FLUSH_DELAY_MS = 60000L;

    // State handles are created in open(); transient so they are never part of the
    // serialised function instance (consistent with the other processors in this package).
    private transient ListState<ResultRecord> pointBufferState;
    private transient ValueState<Long> timerState;

    /**
     * Creates the keyed state handles: the point buffer and the pending timer timestamp.
     *
     * @param parameters Flink configuration (unused)
     */
    @Override
    public void open(Configuration parameters) {
        ListStateDescriptor<ResultRecord> bufferDesc = new ListStateDescriptor<>("pointBuffer", ResultRecord.class);
        pointBufferState = getRuntimeContext().getListState(bufferDesc);

        ValueStateDescriptor<Long> timerDesc = new ValueStateDescriptor<>("timerState", Long.class);
        timerState = getRuntimeContext().getState(timerDesc);
    }

    /**
     * Buffers the record and arms the flush timer if none is pending for the current key.
     *
     * @param record incoming point
     * @param ctx    keyed context providing the timer service
     * @param out    unused here; output happens in {@link #onTimer}
     * @throws Exception if state access fails
     */
    @Override
    public void processElement(ResultRecord record, Context ctx, Collector<ResultOutput> out) throws Exception {
        // 1. Buffer the current point.
        pointBufferState.add(record);

        // 2. Arm the processing-time timer; an already pending, earlier timer is kept.
        Long currentTimer = timerState.value();
        long newTimer = ctx.timerService().currentProcessingTime() + FLUSH_DELAY_MS;
        if (currentTimer == null || newTimer < currentTimer) {
            ctx.timerService().registerProcessingTimeTimer(newTimer);
            timerState.update(newTimer);
        }
    }

    /**
     * Flushes the buffer of the current key: sorts, selects key points, emits and clears state.
     *
     * @param timestamp firing time of the timer
     * @param ctx       timer context providing the current key
     * @param out       collector receiving exactly one {@link ResultOutput}
     * @throws Exception if state access fails
     */
    @Override
    public void onTimer(long timestamp, OnTimerContext ctx, Collector<ResultOutput> out) throws Exception {
        // 1. Load all buffered points and order them by row number.
        List<ResultRecord> allPoints = new ArrayList<>();
        for (ResultRecord point : pointBufferState.get()) {
            allPoints.add(point);
        }
        allPoints.sort(Comparator.comparingLong(p -> p.rowNum));

        // 2. Run the key point selection (port of the original Python logic).
        List<ResultRecord> keyPoints = selectKeyPoints(allPoints);

        // 3. Serialise the key points.
        String vectorArray = generateVectorArray(keyPoints);
        logger.info(">>>>>>>>KeyPointSelector id:{},eventTime:{},vectorArray:{}", ctx.getCurrentKey().f0,
                ctx.getCurrentKey().f1,
                vectorArray);

        // 4. Emit the result for this key.
        out.collect(new ResultOutput(
                ctx.getCurrentKey().f0,
                ctx.getCurrentKey().f1,
                vectorArray
        ));

        // 5. Clear state so the next record starts a fresh batch.
        pointBufferState.clear();
        timerState.clear();
    }

    /**
     * Selects key points from an ordered list of points.
     *
     * <p>The first point is always kept. Every later point is compared with the most recently
     * kept point: it is kept when the distance between the two exceeds
     * {@link #DISTANCE_THRESHOLD} and the absolute direction of the displacement
     * ({@code atan2(dy, dx)} in degrees) exceeds {@link #ANGLE_THRESHOLD}.
     *
     * @param points points ordered by row number
     * @return the selected key points, empty when {@code points} is empty
     */
    private List<ResultRecord> selectKeyPoints(List<ResultRecord> points) {
        List<ResultRecord> keyPoints = new ArrayList<>();
        if (points.isEmpty()) {
            return keyPoints;
        }

        ResultRecord currentKeyPoint = points.get(0);
        keyPoints.add(currentKeyPoint);

        for (int i = 1; i < points.size(); i++) {
            ResultRecord candidate = points.get(i);

            // Displacement from the last key point.
            double dx = candidate.positionX - currentKeyPoint.positionX;
            double dy = candidate.positionY - currentKeyPoint.positionY;

            double distance = Math.sqrt(dx * dx + dy * dy);
            double angle = Math.abs(Math.atan2(dy, dx) * 180 / Math.PI);

            if (distance > DISTANCE_THRESHOLD && angle > ANGLE_THRESHOLD) {
                keyPoints.add(candidate);
                currentKeyPoint = candidate; // later candidates are measured from this point
            }
        }
        return keyPoints;
    }

    /**
     * Formats points as a two-dimensional array string, e.g. {@code [[12.50,40.00],[30.00,41.25]]}.
     *
     * <p>Each coordinate is converted from an absolute position to a percentage of the record's
     * resolution and printed with two decimals. Formatting is pinned to {@link java.util.Locale#ROOT}:
     * with the JVM default locale a comma-decimal locale would print {@code 12,50} and make the
     * array ambiguous.
     *
     * @param points points to serialise
     * @return the array string; {@code []} when {@code points} is empty
     */
    private String generateVectorArray(List<ResultRecord> points) {
        StringBuilder sb = new StringBuilder("[");
        for (int i = 0; i < points.size(); i++) {
            ResultRecord p = points.get(i);
            // Absolute coordinate -> percentage of the screen resolution.
            // NOTE(review): a zero resolution is not guarded here — confirm upstream guarantees it.
            double xPercent = (p.positionX / p.resolutionX) * 100;
            double yPercent = (p.positionY / p.resolutionY) * 100;
            sb.append(String.format(java.util.Locale.ROOT, "[%.2f,%.2f]", xPercent, yPercent));
            if (i < points.size() - 1) {
                sb.append(",");
            }
        }
        sb.append("]");
        return sb.toString();
    }
}
......@@ -6,6 +6,7 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.apache.commons.lang3.StringUtils;
import org.apache.flink.api.common.state.MapState;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
......@@ -17,6 +18,8 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.flink.achieve.doris.VectorAngleCalculationAchi.PointRecord;
import com.flink.util.CompareUtils;
import com.flink.util.LoadPropertiesFile;
import com.flink.vo.CollectLogToJsonSource;
import com.flink.vo.EventList;
import com.flink.vo.EventLogToJsonSource;
......@@ -79,7 +82,16 @@ public class PointRecordJoinProcessor extends CoProcessFunction<EventLogToJsonSo
Collector<PointRecord> out) {
for (EventList eventLogInfo : event.getEventList()) {
List<String> pointList = Optional.ofNullable(eventLogInfo.getR8())
.map(Properties::getR6).orElse(Collections.emptyList());
.filter(properties -> { // 过滤:只处理满足条件的Properties对象
String r4 = properties.getR4();
return StringUtils.isNotEmpty(r4) &&
CompareUtils.stringExists(event.getApp_key(),
LoadPropertiesFile.getPropertyFileValues("chainless.android.appKey"),
LoadPropertiesFile.getPropertyFileValues("chainless.ios.appKey"))
&& (r4.contains("Login") || r4.contains("CountryRegion") || r4.contains("BiometricAuthenticationScreen"));
})
.map(Properties::getR6)
.orElse(Collections.emptyList());
for(String pointStr : pointList) {
String points = cleanPointString(pointStr);
if (points.isEmpty()) continue;
......@@ -96,22 +108,11 @@ public class PointRecordJoinProcessor extends CoProcessFunction<EventLogToJsonSo
for (int i = 0; i < points.length; i++) {
String trimmed = points[i].trim();
if (!isValidPointFormat(trimmed)) continue;
String[] xy = splitPoint(trimmed);
if (xy.length != 2) continue;
try {
double x = Double.parseDouble(xy[0]);
double y = Double.parseDouble(xy[1]);
logger.info("parseAndEmitPoints params id:{},r9:{},i:{},xy0:{},xy1:{},Resolution_x:{},Resolution_y:{}",eventId,
timestamp,
i,
Double.parseDouble(xy[0].trim()),
Double.parseDouble(xy[1].trim()),
collectLog.getResolution_x(),
collectLog.getResolution_y());
out.collect(new PointRecord(eventId, timestamp, i,
x, y,
collectLog.getResolution_x(),
......
package com.flink.processor.function;
import java.util.concurrent.TimeUnit;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.flink.achieve.doris.VectorAngleCalculationAchi.AggregatedRecord;
import com.flink.achieve.doris.VectorAngleCalculationAchi.ResultRecord;
/**
 * Aggregates the {@link ResultRecord}s of one key and emits a single {@link AggregatedRecord}
 * when an event-time timer fires.
 *
 * <p>For every included record the X/Y vector components are summed and the record with the
 * largest {@code angleV} is remembered. The emitted record copies the fields of that record and
 * carries the sums in {@code vectorDiffX}/{@code vectorDiffY}.
 *
 * @author wjs
 * @version created 2025-6-24 17:19:06
 */
public class VectorAggregationFunction extends KeyedProcessFunction<Tuple2<String, Long>, ResultRecord, AggregatedRecord> {

    private static final long serialVersionUID = 1L;
    private static final Logger logger = LoggerFactory.getLogger(VectorAggregationFunction.class);

    // Keyed aggregation state; created in open().
    private transient ValueState<AggregationState> state;

    /**
     * Creates the keyed state handle.
     *
     * @param parameters Flink configuration (unused)
     */
    @Override
    public void open(Configuration parameters) {
        ValueStateDescriptor<AggregationState> descriptor =
                new ValueStateDescriptor<>("aggregationState", AggregationState.class);
        state = getRuntimeContext().getState(descriptor);
    }

    /**
     * Folds one record into the aggregation state and registers the event-time timer
     * five minutes after the record's timestamp.
     *
     * @param record incoming record
     * @param ctx    keyed context providing the element timestamp and timer service
     * @param out    unused here; output happens in {@link #onTimer}
     * @throws Exception if state access fails
     */
    @Override
    public void processElement(
            ResultRecord record,
            Context ctx,
            Collector<AggregatedRecord> out) throws Exception {
        AggregationState currentState = state.value();
        if (currentState == null) {
            currentState = new AggregationState();
        }

        // Inclusion rule, mirrors SQL: if(vector_m >= 5, if(angle_v >= 3, 1, 0), 1)
        boolean shouldInclude = true;
        if (record.vectorM >= 5) {
            shouldInclude = (record.angleV >= 3);
        }

        if (shouldInclude) {
            currentState.sumX += record.vectorX;
            currentState.sumY += record.vectorY;
            // Track the record with the largest angleV (SQL: max(angle_v)).
            if (record.angleV > currentState.maxAngleV) {
                currentState.maxAngleV = record.angleV;
                currentState.maxAngleRecord = record;
            }
        }

        state.update(currentState);

        // Fire at element timestamp + 5 minutes; Flink de-duplicates timers with equal timestamps.
        // NOTE(review): ctx.timestamp() is unboxed here, so the stream is assumed to carry timestamps.
        long windowEnd = ctx.timestamp() + TimeUnit.MINUTES.toMillis(5);
        ctx.timerService().registerEventTimeTimer(windowEnd);
    }

    /**
     * Emits the aggregated record for the current key, if any record was included, and clears state.
     *
     * @param timestamp firing time of the timer
     * @param ctx       timer context
     * @param out       collector receiving at most one {@link AggregatedRecord}
     * @throws Exception if state access fails
     */
    @Override
    public void onTimer(long timestamp, OnTimerContext ctx, Collector<AggregatedRecord> out) throws Exception {
        AggregationState currentState = state.value();
        if (currentState != null && currentState.maxAngleRecord != null) {
            ResultRecord maxRecord = currentState.maxAngleRecord;

            AggregatedRecord result = new AggregatedRecord();
            result.id = maxRecord.id;
            result.eventTime = maxRecord.eventTime;
            result.rowNum = maxRecord.rowNum;
            result.vectorX = maxRecord.vectorX;
            result.vectorY = maxRecord.vectorY;
            result.angleV = maxRecord.angleV;
            result.vectorM = maxRecord.vectorM;
            result.vectorDiffX = currentState.sumX;
            result.vectorDiffY = currentState.sumY;
            result.resolutionX = maxRecord.resolutionX;
            result.resolutionY = maxRecord.resolutionY;
            out.collect(result);
        }
        state.clear();
    }

    /** Per-key aggregation state. */
    private static class AggregationState {
        public double sumX = 0; // running sum of vectorX
        public double sumY = 0; // running sum of vectorY
        // Must start below every possible angle. Double.MIN_VALUE is the smallest POSITIVE
        // double, so using it would reject records whose angleV is 0 or negative.
        public double maxAngleV = Double.NEGATIVE_INFINITY;
        public ResultRecord maxAngleRecord; // record that produced maxAngleV
    }
}
......@@ -117,7 +117,7 @@ public class VectorAngleProcessor extends KeyedProcessFunction<Tuple2<String, Lo
));
}
// // 向量计算状态类
// 向量计算状态类
public static class VectorState {
public double prevPositionX;
public double prevPositionY;
......
package com.flink.processor.function;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.flink.achieve.doris.VectorAngleCalculationAchi.ResultOutput;
import com.flink.achieve.doris.VectorAngleCalculationAchi.ResultRecord;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
/**
 * Collects the {@link ResultRecord}s of one key and, when a processing-time timer fires,
 * emits the selected key points as one {@link ResultOutput}.
 *
 * <p>Key points are chosen from the records ordered by {@code rowNum} and written as a
 * {@code [[x,y],...]} string whose coordinates are percentages of the record's resolution,
 * rounded to six decimals.
 *
 * @author wjs
 * @version created 2025-6-26 16:02:01
 */
public class VectorAngleStyleKeyPointSelector extends KeyedProcessFunction<Tuple2<String, Long>, ResultRecord, ResultOutput> {

    private static final long serialVersionUID = 1L;
    private static final Logger logger = LoggerFactory.getLogger(VectorAngleStyleKeyPointSelector.class);

    // A candidate must exceed both thresholds to become a key point.
    private static final double THRESHOLD_LONG = 15.0;
    private static final double THRESHOLD_ANGLE = 10.0;

    // Keyed state: buffered records and the timestamp of the pending timer.
    private ListState<ResultRecord> bufferState;
    private ValueState<Long> timerState;

    /**
     * Creates the keyed state handles.
     *
     * @param parameters Flink configuration (unused)
     */
    @Override
    public void open(Configuration parameters) {
        ListStateDescriptor<ResultRecord> bufferDescriptor =
                new ListStateDescriptor<>("bufferState", ResultRecord.class);
        bufferState = getRuntimeContext().getListState(bufferDescriptor);

        ValueStateDescriptor<Long> timerDescriptor =
                new ValueStateDescriptor<>("timerState", Long.class);
        timerState = getRuntimeContext().getState(timerDescriptor);
    }

    /**
     * Buffers the record and arms a one-minute processing-time timer unless an earlier
     * (or equal) one is already pending for the key.
     *
     * @param record incoming record
     * @param ctx    keyed context providing the timer service
     * @param out    unused here; output happens in {@link #onTimer}
     * @throws Exception if state access fails
     */
    @Override
    public void processElement(ResultRecord record, Context ctx, Collector<ResultOutput> out) throws Exception {
        bufferState.add(record);

        Long pending = timerState.value();
        long candidate = ctx.timerService().currentProcessingTime() + 60000; // one minute
        if (pending != null && candidate >= pending) {
            return; // the pending timer already covers this record
        }
        ctx.timerService().registerProcessingTimeTimer(candidate);
        timerState.update(candidate);
    }

    /**
     * Sorts the buffered records, selects the key points, emits them and clears the state.
     *
     * @param timestamp firing time of the timer
     * @param ctx       timer context providing the current key
     * @param out       collector receiving exactly one {@link ResultOutput}
     * @throws Exception if state access fails
     */
    @Override
    public void onTimer(long timestamp, OnTimerContext ctx, Collector<ResultOutput> out) throws Exception {
        List<ResultRecord> ordered = new ArrayList<>();
        bufferState.get().forEach(ordered::add);
        ordered.sort(Comparator.comparingLong(r -> r.rowNum));

        String payload = toCoordinateArray(pickKeyPoints(ordered));

        Tuple2<String, Long> key = ctx.getCurrentKey();
        out.collect(new ResultOutput(key.f0, key.f1, payload)); // id, eventTime, coordinates

        bufferState.clear();
        timerState.clear();
    }

    /**
     * Keeps the first record and every later record whose displacement from the last kept
     * record is longer than {@link #THRESHOLD_LONG} and whose absolute direction
     * ({@code atan2(dy, dx)} in degrees) is larger than {@link #THRESHOLD_ANGLE}.
     */
    private List<ResultRecord> pickKeyPoints(List<ResultRecord> ordered) {
        List<ResultRecord> selected = new ArrayList<>();
        if (ordered.isEmpty()) {
            return selected;
        }

        ResultRecord anchor = ordered.get(0);
        selected.add(anchor);

        for (ResultRecord point : ordered.subList(1, ordered.size())) {
            double dx = point.positionX - anchor.positionX;
            double dy = point.positionY - anchor.positionY;
            double length = Math.sqrt(dx * dx + dy * dy);
            double angle = Math.abs(Math.atan2(dy, dx) * 180 / Math.PI);

            if (length > THRESHOLD_LONG && angle > THRESHOLD_ANGLE) {
                selected.add(point);
                anchor = point; // following records are measured from the new key point
            }
        }
        return selected;
    }

    /** Renders the points as {@code [[x,y],...]} using percentage coordinates rounded to 6 places. */
    private String toCoordinateArray(List<ResultRecord> points) {
        StringBuilder text = new StringBuilder("[");
        boolean first = true;
        for (ResultRecord point : points) {
            double xPercent = round((point.positionX / point.resolutionX) * 100, 6);
            double yPercent = round((point.positionY / point.resolutionY) * 100, 6);
            if (!first) {
                text.append(",");
            }
            text.append("[").append(xPercent).append(",").append(yPercent).append("]");
            first = false;
        }
        return text.append("]").toString();
    }

    /**
     * Rounds {@code value} to {@code places} decimal places.
     *
     * @throws IllegalArgumentException if {@code places} is negative
     */
    private double round(double value, int places) {
        if (places < 0) {
            throw new IllegalArgumentException();
        }
        long factor = (long) Math.pow(10, places);
        long scaled = Math.round(value * factor);
        return (double) scaled / factor;
    }
}
package com.flink.processor.function;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.flink.achieve.doris.VectorAngleCalculationAchi.ResultOutput;
import com.flink.achieve.doris.VectorAngleCalculationAchi.ResultRecord;
/**
 * Buffers the {@link ResultRecord}s of one key and, when a processing-time timer fires,
 * splits them into consecutive groups, keeps the record(s) with the largest {@code angleV}
 * of every group and emits their cumulative vectors as one {@link ResultOutput}.
 *
 * @author wjs
 * @version created 2025-6-25 16:03:37
 */
public class VectorDifferenceProcessor extends KeyedProcessFunction<Tuple2<String, Long>, ResultRecord, ResultOutput> {

    private static final long serialVersionUID = 1L;
    private static final Logger logger = LoggerFactory.getLogger(VectorDifferenceProcessor.class);

    // Keyed state: buffered records and the timestamp of the pending timer.
    private ListState<ResultRecord> bufferState;
    private ValueState<Long> timerState;

    /**
     * Creates the keyed state handles.
     *
     * @param parameters Flink configuration (unused)
     */
    @Override
    public void open(Configuration parameters) {
        ListStateDescriptor<ResultRecord> bufferDescriptor =
                new ListStateDescriptor<>("bufferState", ResultRecord.class);
        bufferState = getRuntimeContext().getListState(bufferDescriptor);

        ValueStateDescriptor<Long> timerDescriptor =
                new ValueStateDescriptor<>("timerState", Types.LONG);
        timerState = getRuntimeContext().getState(timerDescriptor);
    }

    /**
     * Buffers the record and arms a one-minute processing-time timer when none is pending
     * (a stored value of 0 or no value counts as "none") or the new one would fire earlier.
     *
     * @param dot incoming record
     * @param ctx keyed context providing the timer service
     * @param out unused here; output happens in {@link #onTimer}
     * @throws Exception if state access fails
     */
    @Override
    public void processElement(ResultRecord dot, Context ctx, Collector<ResultOutput> out) throws Exception {
        bufferState.add(dot);

        Long stored = timerState.value();
        long pending = (stored == null) ? 0 : stored;
        long candidate = ctx.timerService().currentProcessingTime() + TimeUnit.MINUTES.toMillis(1);
        if (pending == 0 || candidate < pending) {
            ctx.timerService().registerProcessingTimeTimer(candidate);
            timerState.update(candidate);
        }
    }

    /**
     * Processes the buffer of the current key and emits one {@link ResultOutput} whose payload
     * is a {@code [[x,y],...]} string of percentage values rounded to six decimals.
     *
     * @param timestamp firing time of the timer
     * @param ctx       timer context providing the current key
     * @param out       collector receiving exactly one {@link ResultOutput}
     * @throws Exception if state access fails
     */
    @Override
    public void onTimer(long timestamp, OnTimerContext ctx, Collector<ResultOutput> out) throws Exception {
        // 1. Load the buffered records ordered by row number.
        List<ResultRecord> ordered = new ArrayList<>();
        bufferState.get().forEach(ordered::add);
        ordered.sort(Comparator.comparingLong(d -> d.rowNum));

        // 2. + 3. Group the records and keep the max-angle point(s) of each group.
        List<VectorPoint> winners = filterMaxAnglePoints(calculateGrGroups(ordered));
        winners.sort(Comparator.comparingLong(v -> v.rowNum));

        // 4. Convert every winner to percentages and append it to the array string.
        StringBuilder text = new StringBuilder("[");
        boolean first = true;
        for (VectorPoint point : winners) {
            double xPercent = round((point.vectorDiffX / point.resolutionX) * 100, 6);
            double yPercent = round((point.vectorDiffY / point.resolutionY) * 100, 6);
            if (!first) {
                text.append(",");
            }
            text.append("[").append(xPercent).append(",").append(yPercent).append("]");
            first = false;
            logger.info(">>>>>>>>准备输出要计算的数据: vectorDiffX:{},resolutionX:{},vectorDiffY:{},resolutionY:{}", point.vectorDiffX, point.resolutionX, point.vectorDiffY, point.resolutionY);
        }
        String vectorArray = text.append("]").toString();

        // 5. Emit the result for this key.
        String id = ctx.getCurrentKey().f0;
        Long eventTime = ctx.getCurrentKey().f1;
        logger.info(">>>>>>>>准备输出 最终结果: id:{},eventTime:{},vectorArray:{}", id, eventTime, vectorArray);
        out.collect(new ResultOutput(id, eventTime, vectorArray));

        // 6. Clear state so the next record starts a fresh batch.
        bufferState.clear();
        timerState.clear();
    }

    /**
     * Assigns every record to a group. The group counter is incremented before a record is
     * assigned unless the record has {@code vectorM >= 10} and {@code angleV < 15}; such a
     * record joins the group of its predecessor.
     */
    private List<GrGroup> calculateGrGroups(List<ResultRecord> data) {
        Map<Integer, GrGroup> byGr = new HashMap<>();
        int gr = 0;
        for (ResultRecord dot : data) {
            boolean startsNewGroup = !(dot.vectorM >= 10) || dot.angleV >= 15;
            if (startsNewGroup) {
                gr++;
            }
            byGr.computeIfAbsent(gr, GrGroup::new).addDot(dot);
        }
        return new ArrayList<>(byGr.values());
    }

    /**
     * For each group, builds the running sums of {@code vectorX}/{@code vectorY} in row order
     * and returns every point whose {@code angleV} is within 1e-6 of the group's maximum.
     */
    private List<VectorPoint> filterMaxAnglePoints(List<GrGroup> groups) {
        List<VectorPoint> winners = new ArrayList<>();
        for (GrGroup group : groups) {
            List<ResultRecord> members = new ArrayList<>(group.getDots());
            members.sort(Comparator.comparingLong(d -> d.rowNum));

            // Single pass: cumulative sums and the maximum angle of the group.
            List<VectorPoint> cumulative = new ArrayList<>(members.size());
            double runX = 0;
            double runY = 0;
            double peakAngle = Double.NEGATIVE_INFINITY;
            for (ResultRecord dot : members) {
                runX += dot.vectorX;
                runY += dot.vectorY;
                VectorPoint point = new VectorPoint(
                        dot.rowNum, dot.vectorX, dot.vectorY,
                        runX, runY,
                        dot.angleV, dot.resolutionX, dot.resolutionY);
                cumulative.add(point);
                peakAngle = Math.max(peakAngle, point.angleV);
            }

            for (VectorPoint point : cumulative) {
                if (Math.abs(point.angleV - peakAngle) < 1e-6) {
                    winners.add(point);
                }
            }
        }
        return winners;
    }

    /**
     * Rounds {@code value} to {@code places} decimal places.
     *
     * @param value  value to round
     * @param places number of decimal places
     * @return the rounded value
     */
    public double round(double value, int places) {
        double scale = Math.pow(10, places);
        return Math.round(value * scale) / scale;
    }

    /**
     * A group of consecutive records sharing one group id.
     *
     * @author wjs
     */
    public static class GrGroup {
        public final int grId;
        private final List<ResultRecord> dots = new ArrayList<>();

        public GrGroup(int grId) {
            this.grId = grId;
        }

        public void addDot(ResultRecord dot) {
            dots.add(dot);
        }

        public List<ResultRecord> getDots() {
            return dots;
        }
    }

    /**
     * A vector together with the cumulative sums of its group up to and including it.
     *
     * @author wjs
     */
    public static class VectorPoint {
        public final long rowNum;
        public final double vectorX;
        public final double vectorY;
        public final double vectorDiffX; // cumulative X
        public final double vectorDiffY; // cumulative Y
        public final double angleV;
        public final double resolutionX;
        public final double resolutionY;

        public VectorPoint(long rowNum, double vectorX, double vectorY,
                double vectorDiffX, double vectorDiffY, double angleV,
                double resolutionX, double resolutionY) {
            this.rowNum = rowNum;
            this.vectorX = vectorX;
            this.vectorY = vectorY;
            this.vectorDiffX = vectorDiffX;
            this.vectorDiffY = vectorDiffY;
            this.angleV = angleV;
            this.resolutionX = resolutionX;
            this.resolutionY = resolutionY;
        }
    }
}
package com.flink.processor.function;
import java.awt.Point;
import java.io.Serializable;
import java.util.*;
import java.time.Duration;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.flink.achieve.doris.VectorAngleCalculationAchi;
import com.flink.processor.function.VectorSimilarityProcessor.SimilarityResult;
import com.flink.processor.function.VectorSimilarityProcessor.VectorPair;
/**
 * Buffers {@link VectorPair}s per key and, when a processing-time timer fires, emits one
 * {@link SimilarityResult} per buffered pair.
 *
 * <p>A pair is similar when the average Euclidean distance between the points at equal
 * positions of both vectors is below {@link #THRESHOLD_DISTANCE}. Pairs whose sizes differ by
 * more than {@link #MEMBER_DIFF}, whose first vector has more than {@link #MEMBER_MAX} points,
 * or that have no comparable points are reported as not similar with distance 0.
 *
 * @author wjs
 * @version created 2025-6-26 17:31:25
 */
public class VectorSimilarityProcessor extends KeyedProcessFunction<String, VectorPair, SimilarityResult> {

    private static final long serialVersionUID = 1L;
    private static final Logger logger = LoggerFactory.getLogger(VectorSimilarityProcessor.class);

    private static final int MEMBER_DIFF = 2;
    private static final int MEMBER_MAX = 15;
    private static final double THRESHOLD_DISTANCE = 3;

    // Delay between buffering a pair and flushing it, in milliseconds (1 minute).
    private static final long FLUSH_DELAY_MS = 60000L;
    // Timer timestamps are rounded down to this granularity so that Flink can de-duplicate them.
    private static final long TIMER_GRANULARITY_MS = 1000L;

    // Keyed buffer of pairs waiting for the next timer.
    private transient ListState<VectorPair> bufferState;

    /**
     * Creates the buffer state with a 24-hour TTL.
     *
     * @param conf Flink configuration (unused)
     */
    @Override
    public void open(Configuration conf) {
        // 1. TTL: entries expire 24 hours after creation or last write and are never returned
        //    once expired. Cleanup is delegated to the RocksDB compaction filter
        //    (cleanupIncrementally(10, false) was noted by the author as the alternative).
        StateTtlConfig ttlConfig = StateTtlConfig.newBuilder(Duration.ofHours(24))
                .setUpdateType(StateTtlConfig.UpdateType.OnCreateAndWrite)
                .setStateVisibility(StateTtlConfig.StateVisibility.NeverReturnExpired)
                .cleanupInRocksdbCompactFilter(1000)
                .build();

        // 2. Bind the TTL to the state descriptor.
        ListStateDescriptor<VectorPair> descriptor =
                new ListStateDescriptor<>("vectorBuffer", VectorPair.class);
        descriptor.enableTimeToLive(ttlConfig);

        // 3. Create the state handle.
        bufferState = getRuntimeContext().getListState(descriptor);
    }

    /**
     * Buffers the pair and registers a flush timer about one minute ahead.
     *
     * <p>The timer timestamp is rounded down to a full second. Flink keeps at most one timer
     * per key and timestamp, so a burst of elements shares a timer instead of registering one
     * timer per element.
     *
     * @param pair incoming vector pair
     * @param ctx  keyed context providing the timer service
     * @param out  unused here; output happens in {@link #onTimer}
     * @throws Exception if state access fails
     */
    @Override
    public void processElement(VectorPair pair, Context ctx, Collector<SimilarityResult> out) throws Exception {
        bufferState.add(pair);

        long flushAt = ctx.timerService().currentProcessingTime() + FLUSH_DELAY_MS;
        long coalesced = (flushAt / TIMER_GRANULARITY_MS) * TIMER_GRANULARITY_MS;
        ctx.timerService().registerProcessingTimeTimer(coalesced);
    }

    /**
     * Computes and emits the similarity of every buffered pair, then clears the buffer.
     *
     * @param timestamp firing time of the timer
     * @param ctx       timer context
     * @param out       collector receiving one {@link SimilarityResult} per buffered pair
     * @throws Exception if state access fails
     */
    @Override
    public void onTimer(long timestamp, OnTimerContext ctx, Collector<SimilarityResult> out) throws Exception {
        // 1. Load all buffered pairs.
        List<VectorPair> pairs = new ArrayList<>();
        for (VectorPair pair : bufferState.get()) {
            pairs.add(pair);
        }

        // 2. + 3. Compute and emit the results.
        for (SimilarityResult result : calculateSimilarities(pairs)) {
            logger.info("VectorSimilarityProcessor 结果输入>>>>>>>>>>>>>> pairId:{},isSimilar:{},avgDistance:{} ", result.pairId, result.isSimilar, result.avgDistance);
            out.collect(result);
        }

        // 4. Clear the buffer.
        bufferState.clear();
    }

    /**
     * Computes one result per pair.
     *
     * @param pairs pairs to evaluate
     * @return results in the order of {@code pairs}
     */
    private List<SimilarityResult> calculateSimilarities(List<VectorPair> pairs) {
        List<SimilarityResult> results = new ArrayList<>(pairs.size());
        for (VectorPair pair : pairs) {
            // A missing vector is treated like an empty one instead of failing the whole batch.
            List<Point> vectorA = orEmpty(pair.vectorA);
            List<Point> vectorB = orEmpty(pair.vectorB);
            int comparable = Math.min(vectorA.size(), vectorB.size());

            // Size checks; comparable == 0 would otherwise divide 0.0 by 0 and emit NaN.
            boolean sizeMismatch = Math.abs(vectorA.size() - vectorB.size()) > MEMBER_DIFF
                    || vectorA.size() > MEMBER_MAX;
            if (sizeMismatch || comparable == 0) {
                results.add(new SimilarityResult(pair.id, false, 0));
                continue;
            }

            // Average Euclidean distance over the positions both vectors have.
            double totalDistance = 0;
            for (int i = 0; i < comparable; i++) {
                totalDistance += calculateEuclideanDistance(vectorA.get(i), vectorB.get(i));
            }
            double avgDistance = totalDistance / comparable;

            results.add(new SimilarityResult(pair.id, avgDistance < THRESHOLD_DISTANCE, avgDistance));
        }
        return results;
    }

    /** Returns {@code points}, or an empty list when it is {@code null}. */
    private static List<Point> orEmpty(List<Point> points) {
        return points == null ? Collections.<Point>emptyList() : points;
    }

    // Cosine similarity example (kept for reference, not used):
    // private double cosineSimilarity(List<Point> vecA, List<Point> vecB) {
    //     double dotProduct = 0.0;
    //     double normA = 0.0, normB = 0.0;
    //     for (int i = 0; i < vecA.size(); i++) {
    //         dotProduct += vecA.get(i).x * vecB.get(i).x + vecA.get(i).y * vecB.get(i).y;
    //         normA += Math.pow(vecA.get(i).x, 2) + Math.pow(vecA.get(i).y, 2);
    //         normB += Math.pow(vecB.get(i).x, 2) + Math.pow(vecB.get(i).y, 2);
    //     }
    //     return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
    // }

    /** Returns the Euclidean distance between two points. */
    private double calculateEuclideanDistance(Point a, Point b) {
        return Math.sqrt(Math.pow(a.x - b.x, 2) + Math.pow(a.y - b.y, 2));
    }

    /** Pair of vectors to compare (public fields and no-arg constructor for Flink POJO handling). */
    public static class VectorPair implements Serializable {

        private static final long serialVersionUID = 1L;

        public String id; // unique identifier of the pair
        public List<Point> vectorA;
        public List<Point> vectorB;

        public VectorPair() {} // required for Flink POJOs

        public VectorPair(String id, List<Point> vectorA, List<Point> vectorB) {
            this.id = id;
            this.vectorA = vectorA;
            this.vectorB = vectorB;
        }
    }

    /** Two-dimensional point (public fields and no-arg constructor for Flink POJO handling). */
    public static class Point implements Serializable {

        private static final long serialVersionUID = 1L;

        public double x;
        public double y;

        public Point() {} // required for Flink POJOs

        public Point(double x, double y) {
            this.x = x;
            this.y = y;
        }
    }

    /** Result of a similarity computation for one pair. */
    public static class SimilarityResult {
        public String pairId;
        public boolean isSimilar;
        public double avgDistance;

        public SimilarityResult() {} // required for Flink POJOs

        public SimilarityResult(String pairId, boolean isSimilar, double avgDistance) {
            this.pairId = pairId;
            this.isSimilar = isSimilar;
            this.avgDistance = avgDistance;
        }
    }
}
......@@ -28,6 +28,7 @@ public class EventLogToJsonSource implements Serializable {
private String nick;
private List<EventList> eventList;
private Long createTime;
private String app_key;
private transient JoinKey joinKey; // 非序列化字段
......@@ -39,7 +40,7 @@ public class EventLogToJsonSource implements Serializable {
}
public EventLogToJsonSource(String id, String uniqueId, String deviceId, String cid, String phone, String nick,
List<EventList> eventList, Long createTime) {
List<EventList> eventList, Long createTime,String app_key) {
this.id = id;
this.uniqueId = uniqueId;
this.deviceId = deviceId;
......@@ -48,5 +49,6 @@ public class EventLogToJsonSource implements Serializable {
this.nick = nick;
this.eventList = eventList;
this.createTime = createTime;
this.app_key = app_key;
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment