0%

文章字数:590,阅读全文大约需要2分钟

配置注解

  • @Configuration注解相当于将当前类当作一个bean注入的组件

  • @Bean将当前方法的返回值注入成为组件,beanid是方法名

  • @ComponentScan(value="com.xx.xx")扫描并注入的文件/文件包位置

  • AnnotationConfigApplicationContext(xxx.class)把指定的class当作配置文件加载一个上下文。

  • 指定扫描的包中标记有@Controller注解的类注入

    1
    2
    3
    @ComponentScan(value="com.xx.xx", includeFilters={
    @Filter(type=FilterType.ANNOTATION, classes={Controller.class})
    }, useDefaultFilter=false)

    使用excludeFilters即为排除
    useDefaultFilter即是否使用默认的扫描
    @Filter(type=FilterType.ASSIGNABLE_TYPE), classes={xx.class}指定注入的类

  • 自定义扫描器
    @Filter(type=FilterType.CUSTOM, classes={MyFilter .class}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
public class MyFilter implements TypeFilter{
/** MetadataReader 当前扫描类信息 MetadataReaderFactory 其他任何类信息 */
@Override
public boolean match(MetadataReader r, MetadataReaderFactory f)thros IOException{
// 获取当前类注解信息
r.getAnnotationMetadata();
// 获取当前正在扫描的类信息
r.getClassMetadata();
// 类路径信息
r.getResource();
// 类信息
f.
return false;
}
}

单实例和多实例

  • 单实例在容器初始化时就创建(@Bean创建默认是单实例),可以使用@Lazy强制懒加载,第一次使用再创建。
  • 多实例用时创建(@Scope("prototype")注解)

注入判断

1
2
3
4
5
6
7
8
9
10
11
12
public class MyCondition implements Condition{
@Override
public boolean matches(ConditionContext context, AnnotatedTypeMetadata metadata){
// 获取当前IOC容器使用的beanFactory
context.getBeanFactory();
// 获取当前环境变量(例如当前的操作系统)
Enviroment enviroment = context.getEnviroment();
// 判断当前的操作系统
enviroment.getProperty("os.name");
return false;// 是否注入
}
}

注入时使用判断过滤器

1
2
3
4
5
@Coditional(MyCondition.class)
@Bean
public Object getOb(){
...
}

import注册Bean

注入的几种方式

  • @Bean一般用于导入第三方的组件
  • 包扫描+类上的标注注解
  • @Import能够快速给容器导入组件
  • Factorybean接口实现

import注册Bean的几种方式

1
2
3
// 注册类上直接注册,id是全类名
@Import(value={xx.class})
@Configuration
  • ImportSelector是一个接口,返回需要导入容器的全类名和数组
    1
    2
    3
    4
    5
    6
    public class MySelect implements ImportSelector{
    public String[] selectImports(AnnotationMetadata annotationMetadata) {
    // 返回类的全类名
    return new String[]{"com.xx.yy", "com.xx.zz"};
    }
    }
1
2
3
// 使用选择器注入
@Import(value={MySelect .class})
@Configuration
  • ImportBeanDefinitionRegistrar接口,满足条件注入Bean
1
2
3
4
5
6
7
8
9
10
11
12
13
14
/**AnnotationMetadata 当前类注解信息
* BeanDefinitionRegistry BeanDefinition注册类,把所需要添加到容器的bean加入
*/
public void registerBeanDefinitions(AnnotationMetadata annotationMetadata, BeanDefinitionRegistry registry) {
boolean bean1 = registry.containsBeanDefintion("com.xx.yy");
boolean bean2 = registry.containsBeanDefintion("com.xx.zz");
// 如果这两个bean都存在于容器,那就创建Pig类到容器
if(bean1 && bean2) {
// 注册之前需要先对类封装一下(Spring中很多bean都这个类封装的)
RootBeanDefinition rootBeanDefinition = new BootBeanDefinition(Pig.class);
// 注册进容器
registry.registerBeanDefinition("pig", rootBeanDefinition );
}
}
1
2
3
// 使用ImportBeanDefinitionRegistrar注入
@Import(value={MyregisterBeanDefinitions.class})
@Configuration

FactoryBean接口注入

1
2
3
4
5
6
7
8
9
10
11
12
13
14
public MyFactoryBean implements FactoryBean<MyClass>{
@Override
public MyClass getObject() throws Exception {
return new MyClass();
}
@Override
public Class<?> getObjectType() {
return MyClass.class;
}
@Override
public boolean isSingleton() {
return true;
}
}
1
2
3
4
5
6
7
// 注入FactoryBean即为注入MyFactoryBean所包含的Class
// BeanFactory在获取bean时会判断,如果类型是FactoryBean,会调用其getObject方法,并返回内容。
// &getMyFactoryBean则是获取factory本身
@Bean
public MyFactoryBean getMyFactoryBean() {
return MyFactoryBean();
}

spring容器

  • beanDefinitionMap里面就是Spring的容器

文章字数:306,阅读全文大约需要1分钟

一、StackOverflowError

分析: 栈空间溢出,栈空间是线程私有的,一般用于保存方法体。所以可能是方法体太大或者太多了?

查看源码注释

1
2
3
4
5
6
7
8
9
/**
* Thrown when a stack overflow occurs because an application
* recurses too deeply.
*
* @author unascribed
* @since JDK1.0
*/
public
class StackOverflowError extends VirtualMachineError {

大致意思是:抛出这个异常是因为方法递归太深。也就是每次递归调用新方法都会将上一个方法的运行信息压入栈,当递归太深,导致数据过多,栈空间不足是就会抛出这个错误。

OutOfMemoryError

分析: 内存不足,jvm中方法保存在堆里。应该是内存不够分配新的对象导致的

源码注释

1
2
3
4
5
6
7
8
9
10
11
12
13
14
/**
* Thrown when the Java Virtual Machine cannot allocate an object
* because it is out of memory, and no more memory could be made
* available by the garbage collector.
*
* {@code OutOfMemoryError} objects may be constructed by the virtual
* machine as if {@linkplain Throwable#Throwable(String, Throwable,
* boolean, boolean) suppression were disabled and/or the stack trace was not
* writable}.
*
* @author unascribed
* @since JDK1.0
*/
public class OutOfMemoryError extends VirtualMachineError {

当内存不过创建新对象,gc也不能回收足够的空间时抛出。也就是创建太多的对象实例,或者集合一直扩容此类情况导致的。


文章字数:48,阅读全文大约需要1分钟

依赖

先引入redis

1
2
3
4
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>

代码

  1. 配置监听器
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import edu.zut.ding.listener.RedisExpiredListener;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.listener.PatternTopic;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;

@Configuration
public class RedisListenerConfig{
@Bean
RedisMessageListenerContainer container(RedisConnectionFactory connectionFactory) {
RedisMessageListenerContainer container = new RedisMessageListenerContainer();
container.setConnectionFactory(connectionFactory);
return container;
}
}
  1. 监听消息
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import edu.zut.ding.constants.SystemConstant;
import edu.zut.ding.enums.OrderState;
import edu.zut.ding.service.OrderService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.connection.Message;
import org.springframework.data.redis.listener.KeyExpirationEventMessageListener;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.stereotype.Component;


@Component
public class RedisKeyExpirationListener extends KeyExpirationEventMessageListener {

public RedisKeyExpirationListener(RedisMessageListenerContainer listenerContainer) {
super(listenerContainer);
}

/**
* 针对redis数据失效事件,进行数据处理
* @param message
* @param pattern
*/
@Override
public void onMessage(Message message, byte[] pattern) {
String key= message.toString();
}
}

文章字数:1283,阅读全文大约需要5分钟

StringUtils是一款字符串处理工具,这里列举了一下常用功能

  1. isEmpty(String str) 是否为空,空格字符为false
  2. isNotEmpty(String str) 是否为非空,空格字符为true
  3. isBlank(String str) 是否为空,空格字符为true
  4. isNotBlank(String str) 是否为非空,空格字符为false
  5. trim(String str)去除字符串两端的控制符,空字符串、null 返回 null
  6. trimToEmpty(String str) 去除字符串两端的控制符,空字符串、null 返回””
  7. stripToNull(String str) 去除字符串两端的空白符,空字符串、null 返回null
  8. stripToEmpty(String str) 去除字符串两端的空白符,空字符串、null 返回””
  9. strip(String str, String stripChars) 去掉str两端的在stripChars中的字符
  10. StringUtils.strip(“000000134_76539000”,”0”)=”134_76539”
  11. stripStart (String str,String stripChars) 去除str 前端在stripChars中的字符
  12. stripEnd (String str,String stripChars) 去除str 后端在stripChars中的字符
  13. equals(String str1,String str2) 比较两个字符串是否相等,如果两个均为空则认为相等
  14. indexOf(String str,char searchChar) 返回searchChar 在字符串中第一次出现的位置,如果没找到则返回 -1,如果str 为null 或者 “”,也返回-1
  15. indexOf(String str,char searchChar,int startPos) 返回字符searchChar从startPos开始在字符串str中第一次出现的位置。
  16. contains(String str,char searchChar) str中是否包含字符searchChar,str为null 或者 searchChar为null,返回false 。
  17. StringUtils.contains(“”, “”) = true
  18. StringUtils.contains(“dfg”, “”) = true
  19. containsIgnoreCase(String str,String searchStr) str中是否包含字符searchChar,不区分大小写
    1. int indexOfAny(String str, char[] searchChars) 找出字符数组searchChars中的字符第一次出现在字符串str中的位置。 如果字符数组中的字符都不在字符串中,则返回-1 ,如果字符串为null或””,则返回-1
  20. subString(String str,int start) 从start 开始,包含start 那个字符,得到字符串str 的子串,如果start为负数,则从后面开始数起。如果str 为null 或者 “” 则返回其本身
  21. subStringBefore(String str,String separator) 得到字符串separator第一次出现前的子串。不包含那个字符,如果str 为null 或者 “” 则返回其本身。
  22. subStringAfter(String str,String separator) 得到字符串separator第一次出现后的子串,不包含那个字符,如果 str 为null,或者””,则返回其本身
  23. subString(String str,int start,int end) 同上
    left(String str,int len) 得到字符串str从左边数len长度的子串,如果str 为null 或者 “”,则返回其本身,如果len小于0,则返回””
  24. right(String str,int len)得到字符串str从右边数len长度的子串
  25. mid(String str,int pos,int len) 得到字符串str从pos开始len长度的子串,pos小于0,则设为0。
  26. split(String str) 把字符串拆分成一个字符串数组,用空白符 作为分隔符,字符串为null 返回null,字符串为””,返回空数组{}
  27. split(String str,char c) 按照 char c 拆分字符串
  28. join(Object[] arrey)把数组中的元素连接成一个字符串返回
  29. join(Object[] arrey,char c) 把数组中的元素拼接成一个字符串返回,把分隔符 c 也带上
  30. deleteWhitespace(String str) 删除字符串中的所有空白符,包括转义字符
  31. removeStart(String str,String remove) 如果字符串str是以remove开始,则去掉这个开始,然后返回,否则返回原来的串
  32. removeEnd(String str,String remove) 如果字符串str是以字符串remove结尾,则去掉这个结尾,然后返回,否则返回原来的串。
  33. remove(String str,char remove) 去掉字符串str中所有包含remove的部分,然后返回
  34. replace(String str,String reql,String with) 在字符串text中用with代替repl,替换所有
  35. replaceChars(String str,char old,char new) 在字符串中 new 字符代替 old 字符
  36. replaceChars(String str, String searchChars, String replaceChars) 这个有点特别,先看下面三个例子
  37. StringUtils.replaceChars(“asssdf”,”s”,”yyy”)) = “ayyydf”
  38. StringUtils.replaceChars(“asdf”,”sd”,”y”)) = “ayf”
  39. StringUtils.replaceChars(“assssddddf”,”sd”,”y”))= “ayyyyf”

    解释:为什么会出现上面这样的结果呢?原来这个置换规则是这样的,他是拿searchChars的index,去replaceChars找相应的index然后替换掉,怎么说呢?比如说第一个例子 s 的index 是0,找到yyy相对应的index为0的字符是y。第二个例子 ‘s’ 的index是0,’d’的index是1, 字符’s’ 可以找到对应index为0的 ‘y’,d就找不到index为’1’的的字符了,所以就直接过滤掉了,听明白了吗?

  40. overlay(String str,String new,int start,int end) 用字符串new 覆盖字符串str从start 到 end 之间的串
  41. chop(String str) 去掉字符串的最后一个字符,比如/r/n
  42. repeat(String str,int repart) 重复字符串repeat次
  43. rightPad(String str,int size,String padStr) size长度的字符串,如果不够用padStr补齐
  44. leftPad(String str,int size,String padStr)同上
  45. center(String str,int size)产生一个字符串,长度等于size,str位于新串的中心
  46. swapCase(String str) 字符串中的大写转小写,小写转换为大写

文章字数:341,阅读全文大约需要1分钟

java程序中经常会有中文乱码的问题,通常都是字符编码的问题。这次系统性的了解一下这个问题

发生原因

字符编码的问题主要发生在编码和解码使用了不同的编码方式,体现在代码上就是

  1. 使用相同的字符编码解码和编码

    1
    2
    3
    4
    String test = new String("abc123阿布才");
    byte[] testByte = test.getBytes(Charset.forName("utf-8"));
    String testRes = new String(testByte, Charset.forName("utf-8"));
    System.out.println("testRes = " + testRes);

    结果:

    1
    testRes = abc123阿布才
  2. 使用不同的字符编码解码和编码

    1
    2
    3
    4
    5
    String test = new String("abc123阿布才");
    // 这里变成了gbk编码
    byte[] testByte = test.getBytes(Charset.forName("gbk"));
    String testRes = new String(testByte, Charset.forName("utf-8"));
    System.out.println("testRes = " + testRes);

    结果

    1
    testRes = abc123������

默认编码

程序内部交互一般不会出现字符编码的问题,因为编码和解码都会默认使用内部的默认编码。所以解析和编码是同一种编码类型。但是如果与其它程序交互,因为两个程序的默认编码不同则可能导致这种问题。

  1. 获取jvm当前默认字符编码
    1
    System.out.println(System.getProperty(file.encoding));

无参的String.getBytes()new String(byte[])使用的都是默认编码


文章字数:1208,阅读全文大约需要4分钟

基本概念

一、计算模型——计算图

1.1基本概念

  1. 计算图Tensorflow最基本的概念,Tensorflow中所有的计算都会转为计算图上的节点。
  2. 张量Tensor,可以理解为多维数组(0-d tensor:标量,1-d tensor:向量,2-d tensor:矩阵)表明了数据结构。
  3. flow,体现了计算模型(张量之间通过计算相互转换的过程)。Tensorflow每个计算都是计算图上的节点,节点之间的边描述了计算之间依赖关系。

1.2 基本过程

  1. 定义计算图中所有计算
  2. 执行计算

1.3使用默认计算图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 引入TensorFlow
import tensorflow as tf

# 定义两个张量
a = tf.constant([1.0, 2.0], name="a")
b = tf.constant([2.0, 3.0], name="b")

# Tensorflow会自动将定义的计算转为计算图上的节点
result = a + b

# 系统会自动维护一个默认的计算图
# 可以通过tf.get_default_graph获取当前默认计算图
#通过a.graph可以查看张量所属的计算图。因为我们没指定,所以是默认计算图
print(a.graph is tf.get_default_graph)

1.4生成新的计算图

注:不同的计算图张量和运算都不会共享

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import tensorflow as tf

# 生成新的计算图
g1 = tf.Graph()
# 设为默认
whith g1.as_default():
# 在计算图上定义变量'v',并且初始化为0
v = tf.get_variable('v', initializer=tf.zeros_initalizer()(shape=[1]))

g2 = tf.Graph()
whith g2.as_default():
# 另一个计算图上定义变量'v',初始化为1
v = tf.get_variable('v', initializer=tf.ones_initializer()(shape=[1]))

# 读取g1上的变量
whith tf.Session(graph=g1) as sess:
# 初始化图上的变量
tf.global_variables_initializer().run()
# 退出变量作用域(返回上层)并开启变量复用
whith tf.variable_scope("", reuse=True):
# 输出0
print(sess.run(tf.get_variable("v")))

# 在计算图 g2 中读取变量'v'的取值
with tf.Session(graph=g2) as sess:
tf.global_variables_initializer().run()
with tf.variable_scope("", reuse=True):
#输出为 1
print(sess.run(tf.get_variable("v")))

1.5指定运行计算的设备

1
2
3
g = tf.Graph()
with g.device('/gpu:0'):
result = a + b

1.6常用集合

Tensorflow的图可以有效管理资源(张量、变量等)。其中自动维护的集合就是访问这些资源的有效手段

集合名 集合内容 使用场景
tf.GraphKeys.VARIABLES 所有变量 持久化TensorFlow模型
tf.GraphKeys.TRAINABLE_VARIABLES 可学习的变量(一般指神经网络的参数) 模型迅雷、生成可视化内容
tf.GraphKeys.SUMMARIES 日志生成相关的张量 TensorFlow计算可视化
tf.GraphKeys.QUEUE_RUNNERS 处理输入的QuecueRunner 输入处理
tf.GraphKeys.MOVING_AVERAGE_VARIABLES 所有计算了滑动平均值的变量 计算变量滑动平均值

二、数据类型——张量

2.1概念

  1. 张量可以简单理解为多维数组(矩阵),零阶表示标量(scalar),一节代表向量(vector),n阶代表矩阵。
  2. 张量并没有保存具体数字,保存的是的到这些数字的计算过程

2.2创建张量

1
2
3
4
5
6
7
import tensorflow as tf
# tf.constant是一个计算,这个计算的结果为一个张量,保存在a中
a = tf.constant([1.0, 2.0], name='a')
b = tf.constant([2.0, 3.0], name='b')
result - tf.add(a, b, name='add')
print(result)
# Tensor("add_2:0", shape=(2,), dtype=float32)
  1. 计算结果也是一个张量
  2. 张量的结构中有三个要素,名字(name是唯一标准)、维度(shape)、类型(type)
  3. 张量可以通过’node:src_output’的形式命名,其中node为节点名称,src_output表示当前张量来自节点的第几个输出。shape=(2,)是张量的维度信息,这个说明result是一个一维数组,且数组长度是2。第三个是类型,每个张量都有唯一的类型,Tensorflow的计算类型必须相同。

类型不匹配的例子

1
2
3
4
5
6
7
import tensorflow as tf
a = tf.constant([1, 2], name='a')
b = tf.constant([2.0, 3.0], name='b')
result = a + b
#类型不匹配报错:
#ValueError: Tensor conversion requested dtype int32 for Tensor with dtype float32:
#'Tensor("b_1:0", shape=(2,), dtype=float32)'

2.3tf.constant

用于计算得出张量的方法,原型如下

1
2
3
4
5
6
7
tf.constant(
value,
dtype=None,
shape=None,
name='Const',
verify_shape=False
)

value是必填的

value的数量必须小于shape代表的矩阵最大承受数量,少于则会用最后一个值填充

1
tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3])

2.4tf.Variable

tf.Variable的运算结果也是一个张量

1
2
# 使用正态分布的方式创建一个2*3的矩阵,随机元素标准差为2
tf.Variable(tf.random_normal([2,3], stddev=2))

2.5张量的作用

  1. 对中间计算结果的引用,提高代码可读性
  2. 计算图构造完成后,张量用来计算结果

三、运行模型——会话

Tensorflow中的会话(session)是来执行定义好的运算

1
2
3
with tf.Session() as sess:
#使用创建好的会话来计算关心的结果(默认图)
sess.run()

tf.Tensor.eval函数可以计算一个张量的取值

1
2
3
4
5
6
7
8
9
10
sess = tf.Session()
with sess.as_default():
print(result.eval)

#下面代码和上面功能相同
sess = tf.Session()
#下面两个命令相同
print(sess.run(result))
print(result.evsl(session=sess))
sess.close() #书本上没有这句

tf.InteractiveSession函数可以省去将产生的会话注册为默认会花的过程。

1
2
3
4
#交互式环境下直接构建默认会话
sess = tf.InteractiveSession()
print(result.eval())
sess.close()

文章字数:1639,阅读全文大约需要6分钟

神经网络参数与Tensorflow变量

tf.Variable作用是保存和更新神经网络中的参数

1
2
#声明2 X 3的矩阵变量,元素均值为0,标准差为2的随机数
weights = tf.Variable(tf.random_normal([2,3], stddev=2))

其它生成器

函数 随机数分布 主要参数
tf.random_normal 正太分布 平均值、标准差、取值类型
tf.truncated_normal 正太分布,但如果随机出来的值偏离平均值超过2个标准差,那么值会重新随机 平均值、标准差、取值类型
tf.random_uniform 拼接分布 最小、最大取值,取值类型
tf.random_gamma Gamma分布 形状参数alpha、尺度参数beta、取值类型

使用常数初始化

函数名 功能 样例
tf.zeros 产生全0的数组 tf.zeros([2,3], int32) -> [[0,0,0], [0,0,0]]
tf.ones 产生全1的数组 tf.ones([2,3], int32) ->[[1,1,1], [1,1,1]]
tf.fill 产生指定数字的数组 tf.fill([2,3], 9) -> [[9,9,9], [9,9,9]]
tf.constant 给定值的 tf.constant([1,2,3]) -> [1,2,3]

变量使用的例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import tensorflow as tf

#初始化两个变量
w1 = tf.Variable(tf.random_normal((2,3), stddev=1, seed=1))
w2 = tf.Variable(tf.random_normal((3,1), stddev=1, seed=1))

#输入的特征向量
x = tf.constant([[0.7, 0.9]])

#前向传播
a = tf.matmul(x, w1)
y = tf.matmul(a, w2)

sess = tf.Session()
#初始化变量
init_op = tf.global_variables_initializer()
sess.run(init_op)

print(sess.run(y))
sess.close()

#输出:[[3.957578]]

变量和张量的区别

  1. tf.Variable是一个运算,运算的结果也是一个张量,变量本质是特殊的张量
  2. tf.Variable里包含的就是一个张量
  3. tf.constant声明的是一个常量,tf.Variable是变量,后续会进行模型参数调整
  4. tf.Variable会加入到GraphKeys.VARIABLES中,trainable参数可以区分优化参数(神经网络参数)和其它参数(迭代参数)

训练神经网络

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import tensorflow as tf

# 声明两个参数,即神经网络的层
w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))
w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))

# 给出维度可以降低出错率
# 声明占位符(入口)
x = tf.placeholder(tf.float32, shape=(1, 2), name='input')

# 使用全加,将层连接,形成网络
a = tf.matmul(x, w1)
y = tf.matmul(a, w2)

sess = tf.Session()
# 调用初始化参数
init_go = tf.global_variables_initializer()
sess.run(init_go)

# 给入口占位符值,并运算
print(sess.run(y, feed_dict={x: [[0.7, 0.9]]})) #feed_dict是个字典
# 输出: [[3.957578]]

上面的是运行一个神经网络的代码,输入feed_dict根据神经网络运算得出结果。
神经网络可以通过大量样本调整变量的值,以提高预测结果准确性。

交叉熵和反向传播

  1. 交叉熵:用于表示预测值和实际值差距的值,通过特定公式得出
  2. 反向传播:根据交叉熵反向调整神经网络参数的过程
1
2
3
4
5
6
7
8
9
10
11
12
#使用sigmoid函数将y转换为0~1之间的数值,转换后y代表预测是正样本的概率
y = tf.sigmoid(y)

# 定义损失函数来刻画预测值与真实值得差距,交叉熵
cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0))
+ (1-y_)*tf.log(tf.clip_by_value(1-y, 1e-10, 1.0)))
# 上面操作中的这个方法是为了将值限定在一定范围内
# tf.clip_by_value(A,min,max)
# 学习率
learning_rate = 0.001
# 定义反向传播算法来优化神经网络参数
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)

Tensorflow提供的优化算法有10种,常用的有:tf.train.GradientDescentOptimizer, tf.train.AdamOptimizer, tf.train.MomentumOptimizer。定义反向传播之后,通过运行sess.run(train_step)就可以对所有在GraphKeys.TRAINABLE_VARIABLES集合中的变量来进行优化,使当前batch下损失函数更小。

完整神经网络

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import tensorflow as tf
from numpy.random import RandomState

#定义训练数据batch的大小
batch_size = 8

#定义神经网络参数
w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))
w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))

#在shape的一个维度上使用None可以方便使用不同的batch大小
x = tf.placeholder(tf.float32, shape=(None, 2), name='x-input')
y_ = tf.placeholder(tf.float32, shape=(None, 1), name='y-input')

#前向传播
a = tf.matmul(x, w1)
y = tf.matmul(a, w2)

#损失函数和反向传播
y = tf.sigmoid(y)
cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0))
+ (1-y_)*tf.log(tf.clip_by_value(1-y, 1e-10, 1.0)))
train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)

#随机生成一个模拟数据集
rdm = RandomState(1)
dataset_size = 128
X = rdm.rand(dataset_size, 2)
#定义规则来给出样本标签,x1+x2<1的样例都认为是正样本,其他为负,
Y = [[int(x1 + x2 < 1)] for (x1, x2) in X]

#创建一个会话来运行Tensorflow程序
with tf.Session() as sess:
init_go = tf.global_variables_initializer()
#初始化变量
sess.run(init_go)

#训练之前的参数
print('parameter w1 before train: ', sess.run(w1))
print('parameter w2 before train: ', sess.run(w2))

STEPS = 5000
for i in range(STEPS):
#每次迭代取batch_size个样本进行训练
start = (i * batch_size) % dataset_size
end = min(start+batch_size, dataset_size)

#通过训练样本训练神经网络并更新参数
sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})
if i % 1000 == 0:
#每隔1000步计算所有数据集上的交叉熵并输出
total_cross_entropy = sess.run(cross_entropy, feed_dict={x: X, y_: Y})
print('After %d training_steps, cross entropy on all data is %g'%(i, total_cross_entropy))

print('parameter w1 after train: ', sess.run(w1))
print('parameter w2 after train: ', sess.run(w2))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
parameter w1 before train:  [[-0.8113182   1.4845988   0.06532937]
[-2.4427042 0.0992484 0.5912243 ]]
parameter w2 before train: [[-0.8113182 ]
[ 1.4845988 ]
[ 0.06532937]]
After 0 training_steps, cross entropy on all data is 1.89805
After 1000 training_steps, cross entropy on all data is 0.655075
After 2000 training_steps, cross entropy on all data is 0.626172
After 3000 training_steps, cross entropy on all data is 0.615096
After 4000 training_steps, cross entropy on all data is 0.610309
parameter w1 after train: [[ 0.02476983 0.56948674 1.6921941 ]
[-2.1977348 -0.23668921 1.1143895 ]]
parameter w2 after train: [[-0.45544702]
[ 0.49110925]
[-0.98110336]]

clip_by_value

  1. 可以将一个张量中的数值限制在一个范围之内。(可以避免一些运算错误:可以保证在进行log运算时,不会出现log0这样的错误或者大于1的概率)
  2. tf.clip_by_value(1-y,1e-10,1.0)

交叉熵

主要有两个表达式

  1. 二分类(预测的结果只有两种,0/1)
    image.png
  • yi 表示样本i的label,正类为1,负类为0
  • pi 表示样本i预测为正的概率
  1. 多分类
    image.png
  • M 类别的数量
  • yic 指示变量(0或1),如果该类别和样本i的类别相同就是1,否则是0;
  • pic对于观测样本i属于类别c的预测概率。

文章字数:455,阅读全文大约需要1分钟

ThreadLocal可以在多线程下实现各个线程的数据隔离

存储原理

直接看ThreadLocalget()方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
public T get() {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
return setInitialValue();
}


ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}

可以看出

  1. 数据是使用ThreadLocalMap存储的
  2. ThreadLocalMap是存放在线程对象上,所以可以保证线程之间的独立

再看map.getEntry(this)这句话调用的方法

1
2
3
4
5
6
7
8
private Entry getEntry(ThreadLocal<?> key) {
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
if (e != null && e.get() == key)
return e;
else
return getEntryAfterMiss(key, i, e);
}

可以看出

  1. 存放在线程的Map通过ThreadLocal对象的threadLocalHashCode获取具体是那个对象
  2. ThreadLocal对象只是获取值的key真正的数据保存在Thread线程对象上

弱引用

存储数据的ThreadLocalMap里可以看到

1
2
3
4
5
6
7
8
9
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;

Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
  1. entry的键是被弱引用包裹的,即GC的时候如果没有强引用则会被直接清理。
  2. 即外面的ThreadLocal如果被赋值为null,即取消对象的强引用。不会因为ThreadLocalMap里面还有强引用而无法被清除
  3. 但是value还是强引用,所以如果不remove元素值还是会造成内存泄露

内存泄露的问题

因为只有键有加弱引用,可以使ThreadLocal外界无强引用时直接被GC。但是key还是强引用,所以需要手动remove

其它问题

  1. 虽然ThreadLocal是随着线程消亡,但如果使用线程池,那么就会被复用。因为线程池的原理就是复用线程

文章字数:198,阅读全文大约需要1分钟

java.lang.ThreadLocal为每个线程提供不同的变量拷贝(线程变量)

和其他变量的区别

  1. 全局变量:属于类,类保存在堆中。属于所有线程共有的区域。所以全局变量能够被所有的线程访问到

  2. 局部变量:属于方法,方法存在栈空间,是线程私有的。但是方法的局部变量只属于方法,外部无法访问。

  3. ThreadLocal:属于线程,线程全局使用。一个ThreadLocal存储一个值

基本使用

1
2
3
4
5
6
ThreadLocal<String> tl = new ThreadLocal<String>;

tl.set("xxx");
tl.get();
//在线程结束的时候最好手动清除一下,提高回收效率
tl.remove();

案例

每个线程独立连接sql

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import java.sql.Connection;  
import java.sql.DriverManager;
import java.sql.SQLException;

public class ConnectionManager {

private static ThreadLocal<Connection> connectionHolder = new ThreadLocal<Connection>() {
@Override
protected Connection initialValue() {
Connection conn = null;
try {
conn = DriverManager.getConnection(
"jdbc:mysql://localhost:3306/test", "username",
"password");
} catch (SQLException e) {
e.printStackTrace();
}
return conn;
}
};

public static Connection getConnection() {
return connectionHolder.get();
}

public static void setConnection(Connection conn) {
connectionHolder.set(conn);
}
}

文章字数:316,阅读全文大约需要1分钟

自定义线程池,设置数据

构造方法

1
2
3
4
5
6
7
ThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue)

ThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory)

ThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, RejectedExecutionHandler handler)

ThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory, RejectedExecutionHandler handler)

构造方法参数

名称 类型 含义
corePoolSize int 核心线程池大小
maximumPoolSize int 最大线程池大小
keepAliveTime long 最大线程池空闲时间
unit TimeUnit 时间单位
workQueue BlockingQueue 线程等待队列,如ArrayBlockingQueue,有界队列,构造函数需要传入队列最大值。
threadFactory ThreadFactory 线程创建工厂
handler RejectedExecutionHandler 拒绝策略

预定义线程池

  1. FixedThreadPool:
1
2
3
4
5
public static ExecutorService newFixedThreadPool(int nThreads) {
return new ThreadPoolExecutor(nThreads, nThreads,
0L, TimeUnit.MILLISECONDS,
new LinkedBlockingQueue<Runnable>());
}
  • 核心线程数和最大线程数相同,所以是固定大小。
  • keepAliveTime对核心线程无效
  • LinkedBlockingQueue是无界阻塞队列,最大值是Integer.MAX_VALUE。如果提交速度大于处理速度,会造成队列阻塞,又因为队列无界,所以可能会内存溢出。
  1. CachedThreadPool:
1
2
3
4
5
public static ExecutorService newCachedThreadPool() {
return new ThreadPoolExecutor(0, Integer.MAX_VALUE,
60L, TimeUnit.SECONDS,
new SynchronousQueue<Runnable>());
}
  • 核心线程0,最大线程Integer.MAX_VALUE。无核心线程,最大线程几乎无限。
  • keepAliveTime = 60,即60s后空闲线程自动结束
  • workQueueSynchronousQueue,无缓冲队列。入队和出队必须同时进行。

自定义线程池案例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
public class ThreadTest {

public static void main(String[] args) throws InterruptedException, IOException {
int corePoolSize = 2;
int maximumPoolSize = 4;
long keepAliveTime = 10;
TimeUnit unit = TimeUnit.SECONDS;
BlockingQueue<Runnable> workQueue = new ArrayBlockingQueue<>(2);
ThreadFactory threadFactory = new NameTreadFactory();
RejectedExecutionHandler handler = new MyIgnorePolicy();
ThreadPoolExecutor executor = new ThreadPoolExecutor(corePoolSize, maximumPoolSize, keepAliveTime, unit,
workQueue, threadFactory, handler);
executor.prestartAllCoreThreads(); // 预启动所有核心线程

for (int i = 1; i <= 10; i++) {
MyTask task = new MyTask(String.valueOf(i));
executor.execute(task);
}

System.in.read(); //阻塞主线程
}

static class NameTreadFactory implements ThreadFactory {

private final AtomicInteger mThreadNum = new AtomicInteger(1);

@Override
public Thread newThread(Runnable r) {
Thread t = new Thread(r, "my-thread-" + mThreadNum.getAndIncrement());
System.out.println(t.getName() + " has been created");
return t;
}
}

public static class MyIgnorePolicy implements RejectedExecutionHandler {

public void rejectedExecution(Runnable r, ThreadPoolExecutor e) {
doLog(r, e);
}

private void doLog(Runnable r, ThreadPoolExecutor e) {
// 可做日志记录等
System.err.println( r.toString() + " rejected");
// System.out.println("completedTaskCount: " + e.getCompletedTaskCount());
}
}

static class MyTask implements Runnable {
private String name;

public MyTask(String name) {
this.name = name;
}

@Override
public void run() {
try {
System.out.println(this.toString() + " is running!");
Thread.sleep(3000); //让任务执行慢点
} catch (InterruptedException e) {
e.printStackTrace();
}
}

public String getName() {
return name;
}

@Override
public String toString() {
return "MyTask [name=" + name + "]";
}
}
}